From 9649a2eaee378ac40fa024b05817b74d826051c3 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Fri, 29 May 2026 13:16:10 -0700 Subject: [PATCH 01/15] Retry streamable HTTP initialize failures --- .../src/bin/test_streamable_http_server.rs | 106 +++++---- .../rmcp-client/src/http_client_adapter.rs | 11 + codex-rs/rmcp-client/src/rmcp_client.rs | 206 +++++++++++++++--- .../tests/streamable_http_recovery.rs | 183 ++++++++++++++++ .../tests/streamable_http_test_support.rs | 30 ++- 5 files changed, 467 insertions(+), 69 deletions(-) diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs index 2384d394736..8004aa5dfba 100644 --- a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -63,11 +63,13 @@ struct TestToolServer { const MEMO_URI: &str = "memo://codex/example-note"; const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server."; const MCP_SESSION_ID_HEADER: &str = "mcp-session-id"; +const INITIALIZE_FAILURE_CONTROL_PATH: &str = "/test/control/initialize-failure"; const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; #[derive(Clone, Default)] -struct SessionFailureState { - armed_failure: Arc>>, +struct FailureState { + initialize_failure: Arc>>, + session_post_failure: Arc>>, } #[derive(Clone, Debug)] @@ -79,7 +81,7 @@ struct ArmedFailure { } #[derive(Debug, Deserialize)] -struct ArmSessionPostFailureRequest { +struct ArmFailureRequest { status: u16, remaining: usize, /// Raw `WWW-Authenticate` challenge header field values to add to the failure. 
@@ -97,7 +99,7 @@ struct EchoArgs { #[tokio::main] async fn main() -> Result<(), Box> { let bind_addr = parse_bind_addr()?; - let session_failure_state = SessionFailureState::default(); + let failure_state = FailureState::default(); const MAX_BIND_RETRIES: u32 = 20; const BIND_RETRY_DELAY: Duration = Duration::from_millis(50); @@ -125,6 +127,7 @@ async fn main() -> Result<(), Box> { eprintln!("starting rmcp streamable http test server on http://{actual_bind_addr}/mcp"); let router = Router::new() + .route(INITIALIZE_FAILURE_CONTROL_PATH, post(arm_initialize_failure)) .route( SESSION_POST_FAILURE_CONTROL_PATH, post(arm_session_post_failure), @@ -162,10 +165,10 @@ async fn main() -> Result<(), Box> { ), ) .layer(middleware::from_fn_with_state( - session_failure_state.clone(), - fail_session_post_when_armed, + failure_state.clone(), + fail_post_when_armed, )) - .with_state(session_failure_state); + .with_state(failure_state); let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") { let expected = Arc::new(format!("Bearer {token}")); @@ -404,8 +407,22 @@ async fn require_bearer( } async fn arm_session_post_failure( - State(state): State, - Json(request): Json, + State(state): State, + Json(request): Json, +) -> Result { + arm_failure(&state.session_post_failure, request).await +} + +async fn arm_initialize_failure( + State(state): State, + Json(request): Json, +) -> Result { + arm_failure(&state.initialize_failure, request).await +} + +async fn arm_failure( + armed_failure: &Arc>>, + request: ArmFailureRequest, ) -> Result { let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?; let www_authenticate_headers = request @@ -413,7 +430,7 @@ async fn arm_session_post_failure( .into_iter() .map(|value| HeaderValue::from_str(&value).map_err(|_| StatusCode::BAD_REQUEST)) .collect::, _>>()?; - let armed_failure = if request.remaining == 0 { + let failure = if request.remaining == 0 { None } else { Some(ArmedFailure { @@ -422,45 
+439,56 @@ async fn arm_session_post_failure( www_authenticate_headers, }) }; - *state.armed_failure.lock().await = armed_failure; + *armed_failure.lock().await = failure; Ok(StatusCode::NO_CONTENT) } -async fn fail_session_post_when_armed( - State(state): State, +async fn fail_post_when_armed( + State(state): State, request: Request, next: Next, ) -> Response { - if request.uri().path() != "/mcp" - || request.method() != Method::POST - || !request.headers().contains_key(MCP_SESSION_ID_HEADER) - { + if request.uri().path() != "/mcp" || request.method() != Method::POST { return next.run(request).await; } - { - let mut armed_failure = state.armed_failure.lock().await; - if let Some(failure) = armed_failure.as_mut() - && failure.remaining > 0 - { - failure.remaining -= 1; - let status = failure.status; - let www_authenticate_headers = failure.www_authenticate_headers.clone(); - if failure.remaining == 0 { - *armed_failure = None; - } - let mut response = Response::new(Body::from(format!( - "forced session failure with status {status}" - ))); - *response.status_mut() = status; - for www_authenticate_header in www_authenticate_headers { - response - .headers_mut() - .append(WWW_AUTHENTICATE, www_authenticate_header); - } - return response; - } + let (armed_failure, label) = if request.headers().contains_key(MCP_SESSION_ID_HEADER) { + (&state.session_post_failure, "session") + } else { + (&state.initialize_failure, "initialize") + }; + + if let Some(response) = consume_failure(armed_failure, label).await { + return response; } next.run(request).await } + +async fn consume_failure( + armed_failure: &Arc>>, + label: &str, +) -> Option { + let mut armed_failure = armed_failure.lock().await; + let failure = armed_failure.as_mut()?; + if failure.remaining == 0 { + return None; + } + + failure.remaining -= 1; + let status = failure.status; + let www_authenticate_headers = failure.www_authenticate_headers.clone(); + if failure.remaining == 0 { + *armed_failure = None; + } + let 
mut response = Response::new(Body::from(format!( + "forced {label} failure with status {status}" + ))); + *response.status_mut() = status; + for www_authenticate_header in www_authenticate_headers { + response + .headers_mut() + .append(WWW_AUTHENTICATE, www_authenticate_header); + } + Some(response) +} diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index 6f98f789205..23ec8fca369 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -58,6 +58,8 @@ pub(crate) struct StreamableHttpClientAdapter { pub(crate) enum StreamableHttpClientAdapterError { #[error("streamable HTTP session expired with 404 Not Found")] SessionExpired404, + #[error("streamable HTTP request returned retryable HTTP {0}")] + RetryableHttpStatus(u16), #[error(transparent)] HttpRequest(#[from] ExecServerError), #[error("invalid HTTP header: {0}")] @@ -182,6 +184,11 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { ) { return Ok(StreamableHttpPostResponse::Accepted); } + if is_retryable_http_status(response.status) { + return Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::RetryableHttpStatus(response.status), + )); + } let content_type = response_header(&response.headers, CONTENT_TYPE); let session_id = response_header(&response.headers, HEADER_SESSION_ID); @@ -463,6 +470,10 @@ fn status_is_success(status: u16) -> bool { StatusCode::from_u16(status).is_ok_and(|status| status.is_success()) } +fn is_retryable_http_status(status: u16) -> bool { + matches!(status, 408 | 429 | 500 | 502 | 503 | 504) +} + async fn collect_body( body_stream: &mut HttpResponseBodyStream, ) -> std::result::Result, StreamableHttpError> { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 90b09d724c3..09928fe82f5 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -14,6 
+14,7 @@ use anyhow::anyhow; use codex_api::SharedAuthProvider; use codex_client::maybe_build_rustls_client_config_with_custom_ca; use codex_config::types::McpServerEnvVar; +use codex_exec_server::ExecServerError; use codex_exec_server::HttpClient; use futures::FutureExt; use futures::future::BoxFuture; @@ -74,6 +75,8 @@ use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use codex_config::types::OAuthCredentialsStoreMode; +const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; + enum PendingTransport { InProcess { transport: tokio::io::DuplexStream, @@ -90,6 +93,16 @@ enum PendingTransport { }, } +impl PendingTransport { + fn is_streamable_http(&self) -> bool { + matches!( + self, + PendingTransport::StreamableHttp { .. } + | PendingTransport::StreamableHttpWithOAuth { .. } + ) + } +} + enum ClientState { Connecting { transport: Option, @@ -223,6 +236,13 @@ enum ClientOperationError { Timeout { label: String, duration: Duration }, } +#[derive(Debug, thiserror::Error)] +#[error("handshaking with MCP server failed: {source}")] +struct HandshakeError { + #[source] + source: rmcp::service::ClientInitializeError, +} + pub type Elicitation = CreateElicitationRequestParams; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -396,9 +416,13 @@ impl RmcpClient { } }; - let (service, oauth_persistor) = - Self::connect_pending_transport(pending_transport, client_service.clone(), timeout) - .await?; + let (service, oauth_persistor) = self + .connect_pending_transport_with_initialize_retries( + pending_transport, + client_service.clone(), + timeout, + ) + .await?; let initialize_result_rmcp = service .peer() @@ -849,15 +873,63 @@ impl RmcpClient { Some(duration) => time::timeout(duration, transport) .await .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? 
- .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + .map_err(|source| HandshakeError { source })?, None => transport .await - .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + .map_err(|source| HandshakeError { source })?, }; Ok((Arc::new(service), oauth_persistor)) } + async fn connect_pending_transport_with_initialize_retries( + &self, + initial_transport: PendingTransport, + client_service: ElicitationClientService, + timeout: Option, + ) -> Result<( + Arc>, + Option, + )> { + let should_retry = initial_transport.is_streamable_http(); + let mut pending_transport = Some(initial_transport); + + let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS + .iter() + .copied() + .map(Some) + .chain(std::iter::once(None)); + + for (attempt, retry_delay_ms) in retry_schedule.enumerate() { + let transport = match pending_transport.take() { + Some(transport) => transport, + None => Self::create_pending_transport(&self.transport_recipe).await?, + }; + + match Self::connect_pending_transport(transport, client_service.clone(), timeout).await + { + Ok(result) => return Ok(result), + Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => { + let Some(retry_delay_ms) = retry_delay_ms else { + return Err(error); + }; + let delay = Duration::from_millis(retry_delay_ms); + warn!( + attempt = attempt + 1, + max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, + delay_ms = delay.as_millis(), + error = %error, + "streamable HTTP MCP initialize failed with a retryable error; retrying" + ); + time::sleep(delay).await; + } + Err(error) => return Err(error), + } + } + + unreachable!("initialize retry loop should return on success or final error") + } + async fn run_service_operation( &self, label: &str, @@ -868,31 +940,45 @@ impl RmcpClient { F: Fn(Arc>) -> Fut, Fut: std::future::Future>, { - let service = self.service().await?; - match Self::run_service_operation_once( - Arc::clone(&service), - label, - timeout, - 
self.elicitation_pause_state.clone(), - &operation, - ) - .await - { - Ok(result) => Ok(result), - Err(error) if Self::is_session_expired_404(&error) => { - self.reinitialize_after_session_expiry(&service).await?; - let recovered_service = self.service().await?; - Self::run_service_operation_once( - recovered_service, - label, - timeout, - self.elicitation_pause_state.clone(), - &operation, - ) - .await - .map_err(Into::into) + let mut session_recovery_attempted = false; + let mut retry_attempt = 0; + + loop { + let service = self.service().await?; + match Self::run_service_operation_once( + Arc::clone(&service), + label, + timeout, + self.elicitation_pause_state.clone(), + &operation, + ) + .await + { + Ok(result) => return Ok(result), + Err(error) + if !session_recovery_attempted && Self::is_session_expired_404(&error) => + { + session_recovery_attempted = true; + self.reinitialize_after_session_expiry(&service).await?; + } + Err(error) + if Self::should_retry_tools_list_operation(label, retry_attempt, &error) => + { + let delay = + Duration::from_millis(STREAMABLE_HTTP_RETRY_DELAYS_MS[retry_attempt]); + retry_attempt += 1; + warn!( + label, + attempt = retry_attempt, + max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, + delay_ms = delay.as_millis(), + error = %error, + "MCP service operation failed with a retryable error; retrying" + ); + time::sleep(delay).await; + } + Err(error) => return Err(error.into()), } - Err(error) => Err(error.into()), } } @@ -941,6 +1027,68 @@ impl RmcpClient { }) } + fn should_retry_tools_list_operation( + label: &str, + retry_attempt: usize, + error: &ClientOperationError, + ) -> bool { + label == "tools/list" + && retry_attempt < STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + && Self::is_retryable_service_operation_error(error) + } + + fn is_retryable_service_operation_error(error: &ClientOperationError) -> bool { + let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = + error + else { + return 
false; + }; + + error + .error + .downcast_ref::>() + .is_some_and(Self::is_retryable_streamable_http_error) + } + + fn is_retryable_initialize_error(error: &anyhow::Error) -> bool { + error.chain().any(|source| { + source + .downcast_ref::() + .is_some_and(|error| Self::is_retryable_client_initialize_error(&error.source)) + || source + .downcast_ref::() + .is_some_and(Self::is_retryable_client_initialize_error) + }) + } + + fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool { + match error { + rmcp::service::ClientInitializeError::TransportError { error, context } + if context.as_ref() == "send initialize request" => + { + error + .error + .downcast_ref::>() + .is_some_and(Self::is_retryable_streamable_http_error) + } + _ => false, + } + } + + fn is_retryable_streamable_http_error( + error: &StreamableHttpError, + ) -> bool { + matches!( + error, + StreamableHttpError::Client( + StreamableHttpClientAdapterError::RetryableHttpStatus(_) + | StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest( + _ + )) + ) + ) + } + async fn reinitialize_after_session_expiry( &self, failed_service: &Arc>, diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 087d3d00df6..fd2a19072e3 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,13 +1,196 @@ mod streamable_http_test_support; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use codex_exec_server::Environment; +use codex_exec_server::ExecServerError; +use codex_exec_server::HttpClient; +use codex_exec_server::HttpRequestParams; +use codex_exec_server::HttpRequestResponse; +use codex_exec_server::HttpResponseBodyStream; +use futures::FutureExt as _; +use futures::future::BoxFuture; use pretty_assertions::assert_eq; +use serde_json::Value; 
+use streamable_http_test_support::arm_initialize_failure; use streamable_http_test_support::arm_session_post_failure; use streamable_http_test_support::call_echo_tool; use streamable_http_test_support::create_client; +use streamable_http_test_support::create_client_with_http_client; use streamable_http_test_support::expected_echo_result; use streamable_http_test_support::spawn_streamable_http_server; +#[derive(Clone)] +struct FailFirstMethodHttpClient { + inner: Arc, + method: &'static str, + failures_remaining: Arc, + matching_post_attempts: Arc, +} + +impl FailFirstMethodHttpClient { + fn new(inner: Arc, method: &'static str) -> Self { + Self { + inner, + method, + failures_remaining: Arc::new(AtomicUsize::new(1)), + matching_post_attempts: Arc::new(AtomicUsize::new(0)), + } + } + + fn matching_post_attempts(&self) -> usize { + self.matching_post_attempts.load(Ordering::SeqCst) + } +} + +impl HttpClient for FailFirstMethodHttpClient { + fn http_request( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result> { + self.inner.http_request(params) + } + + fn http_request_stream( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { + let inner = Arc::clone(&self.inner); + let method = self.method; + let failures_remaining = Arc::clone(&self.failures_remaining); + let matching_post_attempts = Arc::clone(&self.matching_post_attempts); + + async move { + if is_json_rpc_method(&params, method) { + matching_post_attempts.fetch_add(1, Ordering::SeqCst); + if failures_remaining + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| { + remaining.checked_sub(1) + }) + .is_ok() + { + return Err(ExecServerError::HttpRequest( + "http/request failed: error sending request for url (simulated no response)" + .to_string(), + )); + } + } + + inner.http_request_stream(params).await + } + .boxed() + } +} + +fn is_json_rpc_method(params: &HttpRequestParams, method: &str) -> bool { + if
!params.method.eq_ignore_ascii_case("POST") { + return false; + } + + params + .body + .as_ref() + .and_then(|body| serde_json::from_slice::(&body.0).ok()) + .and_then(|body| { + body.get("method") + .and_then(Value::as_str) + .map(str::to_string) + }) + .is_some_and(|request_method| request_method == method) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_initialize_retries_retryable_status() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + + arm_initialize_failure(&base_url, /*status*/ 503, /*remaining*/ 1).await?; + + let client = create_client(&base_url).await?; + let result = call_echo_tool(&client, "after-init-retry").await?; + assert_eq!(result, expected_echo_result("after-init-retry")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_initialize_retries_http_request_error() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let http_client = FailFirstMethodHttpClient::new( + Environment::default_for_tests().get_http_client(), + "initialize", + ); + + let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; + let result = call_echo_tool(&client, "after-no-response-retry").await?; + + assert_eq!(http_client.matching_post_attempts(), 2); + assert_eq!(result, expected_echo_result("after-no-response-retry")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let client = create_client(&base_url).await?; + + arm_session_post_failure( + &base_url, + /*status*/ 503, + /*remaining*/ 1, + /*www_authenticate_headers*/ &[], + ) + .await?; + + let tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; + + assert_eq!(tools.tools.len(), 1); + 
assert_eq!(tools.tools[0].name, "echo"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_tools_list_retries_http_request_error() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let http_client = FailFirstMethodHttpClient::new( + Environment::default_for_tests().get_http_client(), + "tools/list", + ); + let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; + + let tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; + + assert_eq!(http_client.matching_post_attempts(), 2); + assert_eq!(tools.tools.len(), 1); + assert_eq!(tools.tools[0].name, "echo"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_initialize_does_not_retry_non_retryable_status() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + + arm_initialize_failure(&base_url, /*status*/ 403, /*remaining*/ 1).await?; + + let error = match create_client(&base_url).await { + Ok(_) => panic!("initialize unexpectedly succeeded after non-retryable HTTP 403"), + Err(error) => error, + }; + assert!(format!("{error:#}").contains("403")); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs index 822acef1a26..324bf06b2a2 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -20,6 +20,7 @@ use anyhow::Context as _; use codex_config::types::OAuthCredentialsStoreMode; use codex_exec_server::Environment; use codex_exec_server::ExecServerClient; +use 
codex_exec_server::HttpClient; use codex_exec_server::RemoteExecServerConnectArgs; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; @@ -43,6 +44,7 @@ use tokio::process::Child; use tokio::process::Command; use tokio::time::sleep; +const INITIALIZE_FAILURE_CONTROL_PATH: &str = "/test/control/initialize-failure"; const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; fn streamable_http_server_bin() -> Result { @@ -74,6 +76,14 @@ pub(crate) fn expected_echo_result(message: &str) -> CallToolResult { } pub(crate) async fn create_client(base_url: &str) -> anyhow::Result { + create_client_with_http_client(base_url, Environment::default_for_tests().get_http_client()) + .await +} + +pub(crate) async fn create_client_with_http_client( + base_url: &str, + http_client: Arc, +) -> anyhow::Result { let client = RmcpClient::new_streamable_http_client( "test-streamable-http", &format!("{base_url}/mcp"), @@ -81,7 +91,7 @@ pub(crate) async fn create_client(base_url: &str) -> anyhow::Result /*http_headers*/ None, /*env_http_headers*/ None, OAuthCredentialsStoreMode::File, - Environment::default_for_tests().get_http_client(), + http_client, /*auth_provider*/ None, ) .await?; @@ -178,6 +188,24 @@ pub(crate) async fn arm_session_post_failure( Ok(()) } +pub(crate) async fn arm_initialize_failure( + base_url: &str, + status: u16, + remaining: usize, +) -> anyhow::Result<()> { + let response = reqwest::Client::new() + .post(format!("{base_url}{INITIALIZE_FAILURE_CONTROL_PATH}")) + .json(&json!({ + "status": status, + "remaining": remaining, + })) + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); + Ok(()) +} + pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> { let listener = TcpListener::bind("127.0.0.1:0")?; let port = listener.local_addr()?.port(); From a81e3add030224e4b643d12b4a637a9f4e1920b5 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 
Jun 2026 00:19:08 +0000 Subject: [PATCH 02/15] Add MCP initialize outcome metric --- codex-rs/Cargo.lock | 1 + codex-rs/rmcp-client/Cargo.toml | 1 + codex-rs/rmcp-client/src/rmcp_client.rs | 138 +++++++++++++++++++++++- 3 files changed, 136 insertions(+), 4 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 2eeafe30b9e..e4ad366b1ba 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -3541,6 +3541,7 @@ dependencies = [ "codex-config", "codex-exec-server", "codex-keyring-store", + "codex-otel", "codex-protocol", "codex-utils-cargo-bin", "codex-utils-home-dir", diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 63c6d74e4a5..0cf9b51af80 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -20,6 +20,7 @@ codex-client = { workspace = true } codex-config = { workspace = true } codex-exec-server = { workspace = true } codex-keyring-store = { workspace = true } +codex-otel = { workspace = true } codex-protocol = { workspace = true } codex-utils-pty = { workspace = true } codex-utils-home-dir = { workspace = true } diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 09928fe82f5..8dc0a78adc3 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -75,6 +75,7 @@ use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use codex_config::types::OAuthCredentialsStoreMode; +const MCP_CLIENT_INITIALIZE_METRIC: &str = "codex.mcp.client.initialize"; const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; enum PendingTransport { @@ -101,6 +102,53 @@ impl PendingTransport { | PendingTransport::StreamableHttpWithOAuth { .. } ) } + + fn metric_transport(&self) -> &'static str { + match self { + PendingTransport::InProcess { .. } => "in_process", + PendingTransport::Stdio { .. } => "stdio", + PendingTransport::StreamableHttp { .. 
} + | PendingTransport::StreamableHttpWithOAuth { .. } => "streamable_http", + } + } +} + +fn initialize_metric_tags( + transport: &'static str, + outcome: &'static str, + attempts: usize, + retry_exhausted: bool, + failure_kind: &'static str, +) -> Vec<(&'static str, String)> { + let attempts = attempts.max(1); + vec![ + ("transport", transport.to_string()), + ("outcome", outcome.to_string()), + ("retried", (attempts > 1).to_string()), + ("attempts", attempts.to_string()), + ("retry_count", attempts.saturating_sub(1).to_string()), + ("retry_exhausted", retry_exhausted.to_string()), + ("failure_kind", failure_kind.to_string()), + ] +} + +fn emit_initialize_metric( + transport: &'static str, + outcome: &'static str, + attempts: usize, + retry_exhausted: bool, + failure_kind: &'static str, +) { + let Some(metrics) = codex_otel::global() else { + return; + }; + + let tags = initialize_metric_tags(transport, outcome, attempts, retry_exhausted, failure_kind); + let tag_refs: Vec<(&str, &str)> = tags + .iter() + .map(|(key, value)| (*key, value.as_str())) + .collect(); + let _ = metrics.counter(MCP_CLIENT_INITIALIZE_METRIC, /*inc*/ 1, &tag_refs); } enum ClientState { @@ -892,6 +940,7 @@ impl RmcpClient { Option, )> { let should_retry = initial_transport.is_streamable_http(); + let metric_transport = initial_transport.metric_transport(); let mut pending_transport = Some(initial_transport); let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS @@ -901,21 +950,50 @@ impl RmcpClient { .chain(std::iter::once(None)); for (attempt, retry_delay_ms) in retry_schedule.enumerate() { + let attempt_count = attempt + 1; let transport = match pending_transport.take() { Some(transport) => transport, - None => Self::create_pending_transport(&self.transport_recipe).await?, + None => match Self::create_pending_transport(&self.transport_recipe).await { + Ok(transport) => transport, + Err(error) => { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + 
/*retry_exhausted*/ false, + "transport_create", + ); + return Err(error); + } + }, }; match Self::connect_pending_transport(transport, client_service.clone(), timeout).await { - Ok(result) => return Ok(result), + Ok(result) => { + emit_initialize_metric( + metric_transport, + "success", + attempt_count, + /*retry_exhausted*/ false, + "none", + ); + return Ok(result); + } Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => { let Some(retry_delay_ms) = retry_delay_ms else { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ true, + "retry_exhausted", + ); return Err(error); }; let delay = Duration::from_millis(retry_delay_ms); warn!( - attempt = attempt + 1, + attempt = attempt_count, max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, delay_ms = delay.as_millis(), error = %error, @@ -923,7 +1001,16 @@ impl RmcpClient { ); time::sleep(delay).await; } - Err(error) => return Err(error), + Err(error) => { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "non_retryable", + ); + return Err(error); + } } } @@ -1211,6 +1298,7 @@ async fn create_oauth_transport_and_runtime( #[cfg(test)] mod tests { + use std::collections::BTreeMap; use std::time::Duration; use pretty_assertions::assert_eq; @@ -1218,6 +1306,10 @@ mod tests { use super::*; + fn metric_tags_map(tags: Vec<(&'static str, String)>) -> BTreeMap<&'static str, String> { + tags.into_iter().collect() + } + #[tokio::test] async fn active_time_timeout_pauses_while_elicitation_is_pending() { let pause_state = ElicitationPauseState::new(); @@ -1236,4 +1328,42 @@ mod tests { assert_eq!(Ok("done"), result); } + + #[test] + fn initialize_metric_tags_record_success_after_retry() { + let tags = metric_tags_map(initialize_metric_tags( + "streamable_http", + "success", + 2, + /*retry_exhausted*/ false, + "none", + )); + + assert_eq!(tags["transport"], "streamable_http"); + 
assert_eq!(tags["outcome"], "success"); + assert_eq!(tags["retried"], "true"); + assert_eq!(tags["attempts"], "2"); + assert_eq!(tags["retry_count"], "1"); + assert_eq!(tags["retry_exhausted"], "false"); + assert_eq!(tags["failure_kind"], "none"); + } + + #[test] + fn initialize_metric_tags_record_retry_exhaustion() { + let tags = metric_tags_map(initialize_metric_tags( + "streamable_http", + "error", + 3, + /*retry_exhausted*/ true, + "retry_exhausted", + )); + + assert_eq!(tags["transport"], "streamable_http"); + assert_eq!(tags["outcome"], "error"); + assert_eq!(tags["retried"], "true"); + assert_eq!(tags["attempts"], "3"); + assert_eq!(tags["retry_count"], "2"); + assert_eq!(tags["retry_exhausted"], "true"); + assert_eq!(tags["failure_kind"], "retry_exhausted"); + } } From 089108666e4d86d9b579417f96efb7b76d31416e Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 16:54:41 -0700 Subject: [PATCH 03/15] Retry remote streamable HTTP no-response failures --- codex-rs/rmcp-client/src/rmcp_client.rs | 18 +- .../tests/streamable_http_recovery.rs | 72 ++++- .../tests/streamable_http_remote.rs | 248 ++++++++++++++++++ 3 files changed, 326 insertions(+), 12 deletions(-) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 8dc0a78adc3..6ca1d88dbfc 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -76,6 +76,7 @@ use crate::utils::build_default_headers; use codex_config::types::OAuthCredentialsStoreMode; const MCP_CLIENT_INITIALIZE_METRIC: &str = "codex.mcp.client.initialize"; +const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; enum PendingTransport { @@ -1165,15 +1166,18 @@ impl RmcpClient { fn is_retryable_streamable_http_error( error: &StreamableHttpError, ) -> bool { - matches!( - error, + match error { StreamableHttpError::Client( StreamableHttpClientAdapterError::RetryableHttpStatus(_) - | 
StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest( - _ - )) - ) - ) + | StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest(_)), + ) => true, + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Server { code, message }, + )) => { + *code == JSON_RPC_INTERNAL_ERROR_CODE && message.starts_with("http/request failed:") + } + _ => false, + } } async fn reinitialize_after_session_expiry( diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fd2a19072e3..ceb10508797 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -24,19 +24,31 @@ use streamable_http_test_support::create_client_with_http_client; use streamable_http_test_support::expected_echo_result; use streamable_http_test_support::spawn_streamable_http_server; +const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; +const SIMULATED_NO_RESPONSE_MESSAGE: &str = + "http/request failed: error sending request for url (simulated no response)"; + +#[derive(Clone, Copy)] +enum RequestFailure { + LocalHttpRequest, + RemoteServer, +} + #[derive(Clone)] struct FailFirstMethodHttpClient { inner: Arc, method: &'static str, + failure: RequestFailure, failures_remaining: Arc, matching_post_attempts: Arc, } impl FailFirstMethodHttpClient { - fn new(inner: Arc, method: &'static str) -> Self { + fn new(inner: Arc, method: &'static str, failure: RequestFailure) -> Self { Self { inner, method, + failure, failures_remaining: Arc::new(AtomicUsize::new(1)), matching_post_attempts: Arc::new(AtomicUsize::new(0)), } @@ -61,6 +73,7 @@ impl HttpClient for FailFirstMethodHttpClient { ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { let inner = Arc::clone(&self.inner); let method = self.method; + let failure = self.failure; let failures_remaining = 
Arc::clone(&self.failures_remaining); let matching_post_attempts = Arc::clone(&self.matching_post_attempts); @@ -73,10 +86,15 @@ impl HttpClient for FailFirstMethodHttpClient { }) .is_ok() { - return Err(ExecServerError::HttpRequest( - "http/request failed: error sending request for url (simulated no response)" - .to_string(), - )); + return Err(match failure { + RequestFailure::LocalHttpRequest => { + ExecServerError::HttpRequest(SIMULATED_NO_RESPONSE_MESSAGE.to_string()) + } + RequestFailure::RemoteServer => ExecServerError::Server { + code: JSON_RPC_INTERNAL_ERROR_CODE, + message: SIMULATED_NO_RESPONSE_MESSAGE.to_string(), + }, + }); } } @@ -122,6 +140,7 @@ async fn streamable_http_initialize_retries_http_request_error() -> anyhow::Resu let http_client = FailFirstMethodHttpClient::new( Environment::default_for_tests().get_http_client(), "initialize", + RequestFailure::LocalHttpRequest, ); let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; @@ -133,6 +152,27 @@ async fn streamable_http_initialize_retries_http_request_error() -> anyhow::Resu Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_initialize_retries_remote_http_request_error() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let http_client = FailFirstMethodHttpClient::new( + Environment::default_for_tests().get_http_client(), + "initialize", + RequestFailure::RemoteServer, + ); + + let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; + let result = call_echo_tool(&client, "after-remote-no-response-retry").await?; + + assert_eq!(http_client.matching_post_attempts(), 2); + assert_eq!( + result, + expected_echo_result("after-remote-no-response-retry") + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result<()> { let (_server, 
base_url) = spawn_streamable_http_server().await?; @@ -162,6 +202,28 @@ async fn streamable_http_tools_list_retries_http_request_error() -> anyhow::Resu let http_client = FailFirstMethodHttpClient::new( Environment::default_for_tests().get_http_client(), "tools/list", + RequestFailure::LocalHttpRequest, + ); + let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; + + let tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; + + assert_eq!(http_client.matching_post_attempts(), 2); + assert_eq!(tools.tools.len(), 1); + assert_eq!(tools.tools[0].name, "echo"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_tools_list_retries_remote_http_request_error() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let http_client = FailFirstMethodHttpClient::new( + Environment::default_for_tests().get_http_client(), + "tools/list", + RequestFailure::RemoteServer, ); let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_remote.rs b/codex-rs/rmcp-client/tests/streamable_http_remote.rs index 0d4690a8255..69f274910d7 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_remote.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_remote.rs @@ -6,7 +6,19 @@ mod streamable_http_test_support; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use anyhow::Context as _; use pretty_assertions::assert_eq; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::task::JoinHandle; use streamable_http_test_support::call_echo_tool; use streamable_http_test_support::create_remote_client; @@ -35,3 +47,239 @@ async fn 
streamable_http_remote_client_round_trips_through_exec_server() -> anyh Ok(()) } + +/// What this tests: when a real remote exec-server sees a no-status network +/// failure during the Streamable HTTP initialize request, it maps the reqwest +/// send failure into a JSON-RPC internal server error and the RMCP client still +/// treats that remote-shaped error as retryable. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streamable_http_remote_initialize_retries_no_response_failure() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let proxy = DropNextMcpPostProxy::spawn(&base_url).await?; + proxy.arm_next_mcp_post_drop(); + let exec_server = spawn_exec_server().await?; + + let client = create_remote_client(proxy.base_url(), exec_server.client.clone()).await?; + let result = call_echo_tool(&client, "remote-init-retry").await?; + + assert_eq!(proxy.dropped_mcp_posts(), 1); + assert_eq!(result, expected_echo_result("remote-init-retry")); + + Ok(()) +} + +/// What this tests: once initialized through the real remote exec-server path, +/// a no-status Streamable HTTP failure during tools/list is retried instead of +/// surfacing the remote JSON-RPC internal server error to the caller. 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streamable_http_remote_tools_list_retries_no_response_failure() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let proxy = DropNextMcpPostProxy::spawn(&base_url).await?; + let exec_server = spawn_exec_server().await?; + let client = create_remote_client(proxy.base_url(), exec_server.client.clone()).await?; + + proxy.arm_next_mcp_post_drop(); + let tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; + + assert_eq!(proxy.dropped_mcp_posts(), 1); + assert_eq!(tools.tools.len(), 1); + assert_eq!(tools.tools[0].name, "echo"); + + Ok(()) +} + +struct DropNextMcpPostProxy { + base_url: String, + drops_remaining: Arc, + dropped_mcp_posts: Arc, + task: JoinHandle<()>, +} + +impl DropNextMcpPostProxy { + async fn spawn(target_base_url: &str) -> anyhow::Result { + let target_addr = parse_target_addr(target_base_url)?; + let listener = TcpListener::bind("127.0.0.1:0").await?; + let proxy_addr = listener.local_addr()?; + let drops_remaining = Arc::new(AtomicUsize::new(0)); + let dropped_mcp_posts = Arc::new(AtomicUsize::new(0)); + let task_drops_remaining = Arc::clone(&drops_remaining); + let task_dropped_mcp_posts = Arc::clone(&dropped_mcp_posts); + + let task = tokio::spawn(async move { + while let Ok((client, _addr)) = listener.accept().await { + let connection_drops_remaining = Arc::clone(&task_drops_remaining); + let connection_dropped_mcp_posts = Arc::clone(&task_dropped_mcp_posts); + tokio::spawn(async move { + let _ = proxy_connection( + client, + target_addr, + connection_drops_remaining, + connection_dropped_mcp_posts, + ) + .await; + }); + } + }); + + Ok(Self { + base_url: format!("http://{proxy_addr}"), + drops_remaining, + dropped_mcp_posts, + task, + }) + } + + fn base_url(&self) -> &str { + &self.base_url + } + + fn arm_next_mcp_post_drop(&self) { + self.drops_remaining.fetch_add(1, Ordering::SeqCst); + } + + fn 
dropped_mcp_posts(&self) -> usize { + self.dropped_mcp_posts.load(Ordering::SeqCst) + } +} + +impl Drop for DropNextMcpPostProxy { + fn drop(&mut self) { + self.task.abort(); + } +} + +async fn proxy_connection( + mut client: TcpStream, + target_addr: SocketAddr, + drops_remaining: Arc, + dropped_mcp_posts: Arc, +) -> anyhow::Result<()> { + let request = read_http_message(&mut client).await?; + if request.is_empty() { + return Ok(()); + } + + if is_mcp_post(&request) && consume_drop(&drops_remaining) { + dropped_mcp_posts.fetch_add(1, Ordering::SeqCst); + return Ok(()); + } + + let request = with_connection_close(request)?; + let mut upstream = TcpStream::connect(target_addr).await?; + upstream.write_all(&request).await?; + tokio::io::copy(&mut upstream, &mut client).await?; + client.shutdown().await?; + + Ok(()) +} + +async fn read_http_message(stream: &mut TcpStream) -> anyhow::Result> { + let mut message = Vec::new(); + let mut header_end = None; + let mut chunk = [0_u8; 4096]; + + while header_end.is_none() { + let bytes_read = stream.read(&mut chunk).await?; + if bytes_read == 0 { + return Ok(message); + } + message.extend_from_slice(&chunk[..bytes_read]); + header_end = find_header_end(&message); + } + + let header_end = header_end.context("HTTP message headers were not terminated")?; + let content_length = content_length(&message[..header_end])?; + let message_len = header_end + content_length; + + while message.len() < message_len { + let bytes_read = stream.read(&mut chunk).await?; + if bytes_read == 0 { + anyhow::bail!("HTTP message ended before body was complete"); + } + message.extend_from_slice(&chunk[..bytes_read]); + } + + Ok(message) +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| position + 4) +} + +fn content_length(headers: &[u8]) -> anyhow::Result { + let headers = std::str::from_utf8(headers).context("HTTP headers were not UTF-8")?; + for line in 
headers.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if name.eq_ignore_ascii_case("content-length") { + return value + .trim() + .parse::() + .context("Content-Length header was not a usize"); + } + } + Ok(0) +} + +fn is_mcp_post(request: &[u8]) -> bool { + let Some(request_line) = std::str::from_utf8(request) + .ok() + .and_then(|request| request.lines().next()) + else { + return false; + }; + let mut parts = request_line.split_whitespace(); + parts.next() == Some("POST") && parts.next() == Some("/mcp") +} + +fn consume_drop(drops_remaining: &AtomicUsize) -> bool { + drops_remaining + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| { + remaining.checked_sub(1) + }) + .is_ok() +} + +fn with_connection_close(request: Vec) -> anyhow::Result> { + let header_end = find_header_end(&request).context("HTTP request headers were not complete")?; + let headers = std::str::from_utf8(&request[..header_end]).context("request was not UTF-8")?; + let mut next_request = Vec::with_capacity(request.len() + "Connection: close\r\n".len()); + + for line in headers + .strip_suffix("\r\n\r\n") + .unwrap_or(headers) + .split("\r\n") + { + if line + .split_once(':') + .is_some_and(|(name, _value)| name.eq_ignore_ascii_case("connection")) + { + continue; + } + next_request.extend_from_slice(line.as_bytes()); + next_request.extend_from_slice(b"\r\n"); + } + next_request.extend_from_slice(b"Connection: close\r\n\r\n"); + next_request.extend_from_slice(&request[header_end..]); + + Ok(next_request) +} + +fn parse_target_addr(base_url: &str) -> anyhow::Result { + let url = reqwest::Url::parse(base_url)?; + let host = url + .host_str() + .context("target URL did not include a host")?; + let port = url + .port_or_known_default() + .context("target URL did not include a port")?; + format!("{host}:{port}") + .parse() + .context("target URL did not resolve to a socket address") +} From d00ea62d33dd4c24a57150901dd061be0e51eaad Mon Sep 17 
00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 17:03:49 -0700 Subject: [PATCH 04/15] codex: fix CI fmt-check recipe on PR #25147 --- justfile | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/justfile b/justfile index d5e9fc3a36e..a088978ac8b 100644 --- a/justfile +++ b/justfile @@ -36,6 +36,12 @@ fmt: uv run --frozen --project ../sdk/python --extra dev ruff check --fix --fix-only ../sdk/python uv run --frozen --project ../sdk/python --extra dev ruff format ../sdk/python +# Check formatting without modifying files. +fmt-check: + cargo fmt -- --check --config imports_granularity=Item 2>/dev/null + uv run --frozen --project ../sdk/python --extra dev ruff check --diff ../sdk/python + uv run --frozen --project ../sdk/python --extra dev ruff format --check ../sdk/python + fix *args: cargo clippy --fix --tests --allow-dirty "$@" From a709d6b6d008ccca32fc26afe318bc03dd76acd9 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 17:36:07 -0700 Subject: [PATCH 05/15] codex: fix CI failures on PR #25147 --- codex-rs/rmcp-client/src/rmcp_client.rs | 4 ++-- sdk/python/pyproject.toml | 6 +++--- sdk/python/uv.lock | 20 +++++++++++--------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 6ca1d88dbfc..b1e4c3dbd38 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1338,7 +1338,7 @@ mod tests { let tags = metric_tags_map(initialize_metric_tags( "streamable_http", "success", - 2, + /*attempts*/ 2, /*retry_exhausted*/ false, "none", )); @@ -1357,7 +1357,7 @@ mod tests { let tags = metric_tags_map(initialize_metric_tags( "streamable_http", "error", - 3, + /*attempts*/ 3, /*retry_exhausted*/ true, "retry_exhausted", )); diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index 9c93be69a5d..b70df0cb534 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -16,7 +16,7 @@ 
classifiers = [ "Intended Audience :: Developers", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["pydantic>=2.12", "openai-codex-cli-bin==0.132.0"] +dependencies = ["pydantic>=2.12", "openai-codex-cli-bin==0.137.0a4"] [project.urls] Homepage = "https://github.com/openai/codex" @@ -78,10 +78,10 @@ combine-as-imports = true [tool.uv] exclude-newer = "7 days" -exclude-newer-package = { openai-codex-cli-bin = "2026-05-20T21:00:00Z" } +exclude-newer-package = { openai-codex-cli-bin = "2026-06-03T19:00:00Z" } index-strategy = "first-index" [tool.uv.pip] exclude-newer = "7 days" -exclude-newer-package = { openai-codex-cli-bin = "2026-05-20T21:00:00Z" } +exclude-newer-package = { openai-codex-cli-bin = "2026-06-03T19:00:00Z" } index-strategy = "first-index" diff --git a/sdk/python/uv.lock b/sdk/python/uv.lock index 6d9e40e37a0..588c9f4fe2b 100644 --- a/sdk/python/uv.lock +++ b/sdk/python/uv.lock @@ -7,7 +7,7 @@ exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for exclude-newer-span = "P7D" [options.exclude-newer-package] -openai-codex-cli-bin = "2026-05-20T21:00:00Z" +openai-codex-cli-bin = "2026-06-03T19:00:00Z" [[package]] name = "annotated-types" @@ -299,7 +299,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "datamodel-code-generator", marker = "extra == 'dev'", specifier = "==0.31.2" }, - { name = "openai-codex-cli-bin", specifier = "==0.132.0" }, + { name = "openai-codex-cli-bin", specifier = "==0.137.0a4" }, { name = "pydantic", specifier = ">=2.12" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.8" }, @@ -308,15 +308,17 @@ provides-extras = ["dev"] [[package]] name = "openai-codex-cli-bin" -version = "0.132.0" +version = "0.137.0a4" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/be/a1/b92b7a1b73a83785d2e1dcd0faecd1b7f886a38cf02a30abe1c35f42f0f7/openai_codex_cli_bin-0.132.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:1c22b51dbd679413f84f00b9d8fd4e5cf8a1c0d1c7cc8c42bcb3f9f1b33e2334", size = 89403211, upload-time = "2026-05-20T02:37:22.311Z" }, - { url = "https://files.pythonhosted.org/packages/5f/68/163272e582de55a7f460e2329281267908d75d0fbcbbbb2c6749a6329e6b/openai_codex_cli_bin-0.132.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:56217495e6635c8a5d96df820cc0da5f46cd9b6ec6f3a5f67f1607d69ef74256", size = 79058685, upload-time = "2026-05-20T02:37:27.165Z" }, - { url = "https://files.pythonhosted.org/packages/0b/18/a60c6b137e7cd3959cae16ba757f57ca5702979b0ea107a21f516ba15d98/openai_codex_cli_bin-0.132.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:09642e7578a3078bfccc82af4077b085d42b0022b529e4b5c645e0a0af3397a4", size = 78689038, upload-time = "2026-05-20T02:37:31.548Z" }, - { url = "https://files.pythonhosted.org/packages/f8/eb/1b184307a67c1006d59b61636bcfcea73a89aa95271f6516ed28dce554ca/openai_codex_cli_bin-0.132.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:85aec095f9d144d7a2d1aff39fce77b7240f42014580c35801ba74b9317aa5f7", size = 85528820, upload-time = "2026-05-20T02:37:36.559Z" }, - { url = "https://files.pythonhosted.org/packages/0e/e8/1b823a8bf7b96d1513905ad79b16a146d797f81a19a6bc350a2f95a16661/openai_codex_cli_bin-0.132.0-py3-none-win_amd64.whl", hash = "sha256:3cb5c90c55baa39bd5ddc890d2068d3e1322a57a54d1d0e623819009a205c7f5", size = 86916218, upload-time = "2026-05-20T02:37:41.886Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e6/bb8634bd4f3adaea299c95d7b03105ac417e32dd6d8bc2af5dda141d6f28/openai_codex_cli_bin-0.132.0-py3-none-win_arm64.whl", hash = "sha256:74ef93d3deef7cb83c71d19fc667defe749cdab337ec331f59a23511561b6f6a", size = 79892931, upload-time = "2026-05-20T02:37:46.828Z" }, + { url = 
"https://files.pythonhosted.org/packages/bd/60/af73ef1676cd477fa83ed4b889bf3b57c63c47dd87025b2cc4262793cff6/openai_codex_cli_bin-0.137.0a4-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b33c3917e0b58d527ee11a11a78ad390f7d8e6aa25577dd21665ab3c8bf5cf9a", size = 94300191, upload-time = "2026-06-03T18:44:36.312Z" }, + { url = "https://files.pythonhosted.org/packages/92/8f/d1a5f8c87176e00ef6a85798794f4530f5eb04e5a1a13468b5b3c3a361f9/openai_codex_cli_bin-0.137.0a4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3d0f0bc5becc88c61952fbfa9bd792ac9d74fa78b3a6bd40f545b612048b07eb", size = 83924479, upload-time = "2026-06-03T18:44:40.854Z" }, + { url = "https://files.pythonhosted.org/packages/3e/3c/fc00bcdc0c302208317d5eb1d0bfaab3024f351cd0121400f19baa6b19aa/openai_codex_cli_bin-0.137.0a4-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:2f1656339e2736868c4cce59f6d9e5c633879123687169b03b1137d42bf2c11a", size = 83363315, upload-time = "2026-06-03T18:44:44.851Z" }, + { url = "https://files.pythonhosted.org/packages/ec/09/39362e944ebeb12fcbfb86881fbb4dd6e806f77f7541c1f1f993bb9351a0/openai_codex_cli_bin-0.137.0a4-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:6454f838d44c56c1ed07a29b391fa412785e5dd2ffd06db0b62e62478c19bb64", size = 90611239, upload-time = "2026-06-03T18:44:49.338Z" }, + { url = "https://files.pythonhosted.org/packages/fa/38/87b1247fdfe95cddce7f7fe8331d6843cf037e14292c0f5004e23247133b/openai_codex_cli_bin-0.137.0a4-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:f5ae7401d00c65d56a75d9645d7bf87d809566a12d238e4b2a8b328a02f2316e", size = 83363315, upload-time = "2026-06-03T18:44:53.428Z" }, + { url = "https://files.pythonhosted.org/packages/fb/c4/3c693ad07e587f6b3a28128c417f2e831d81a40cdbd85c0e5f0f36aaff82/openai_codex_cli_bin-0.137.0a4-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:3dcec1e649448be498d6e7ec0e1f71dca83efa76063d90890dafb41e987069b7", size = 90611238, upload-time = "2026-06-03T18:44:57.612Z" }, + { url = 
"https://files.pythonhosted.org/packages/9e/26/81e037066b9b8d312a6f9e09015e452ce17630d5ab88e02a4c1d9503e4e8/openai_codex_cli_bin-0.137.0a4-py3-none-win_amd64.whl", hash = "sha256:9e13bf68e18e36bd3a0efd51213281c83e9f6ec22bdb7a45bd2e0211822733a9", size = 94744969, upload-time = "2026-06-03T18:45:02.23Z" }, + { url = "https://files.pythonhosted.org/packages/0d/a3/952bc2a5d62373a51fea161effe3b338b3417c2f6e65fe467ed91b205e2b/openai_codex_cli_bin-0.137.0a4-py3-none-win_arm64.whl", hash = "sha256:5ec4303ca2dcb5f838e0de3ca7f44050b6bcdd41d281a178c3a1420a985a515d", size = 86963504, upload-time = "2026-06-03T18:45:07.131Z" }, ] [[package]] From 00ee46ca9f71db45b48e813b64212be82d474748 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 17:54:13 -0700 Subject: [PATCH 06/15] codex: address PR review feedback (#25147) --- .../rmcp-client/src/http_client_adapter.rs | 102 +++++++++++- codex-rs/rmcp-client/src/rmcp_client.rs | 150 +++++++++--------- codex-rs/rmcp-client/src/rmcp_client_tests.rs | 67 ++++++++ .../tests/streamable_http_recovery.rs | 110 ++++++++++++- .../tests/streamable_http_test_support.rs | 2 +- 5 files changed, 342 insertions(+), 89 deletions(-) create mode 100644 codex-rs/rmcp-client/src/rmcp_client_tests.rs diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index 23ec8fca369..fbb8039659f 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -184,25 +184,29 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { ) { return Ok(StreamableHttpPostResponse::Accepted); } + + let content_type = response_header(&response.headers, CONTENT_TYPE); + let session_id = response_header(&response.headers, HEADER_SESSION_ID); + if let Some(content_type) = content_type.as_deref() + && content_type.starts_with(JSON_MIME_TYPE) + { + let body = collect_body(&mut body_stream).await?; + let message: ServerJsonRpcMessage = + 
serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?; + return Ok(StreamableHttpPostResponse::Json(message, session_id)); + } + if is_retryable_http_status(response.status) { return Err(StreamableHttpError::Client( StreamableHttpClientAdapterError::RetryableHttpStatus(response.status), )); } - let content_type = response_header(&response.headers, CONTENT_TYPE); - let session_id = response_header(&response.headers, HEADER_SESSION_ID); match content_type.as_deref() { Some(content_type) if content_type.starts_with(EVENT_STREAM_MIME_TYPE) => { let event_stream = sse_stream_from_body(body_stream); Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) } - Some(content_type) if content_type.starts_with(JSON_MIME_TYPE) => { - let body = collect_body(&mut body_stream).await?; - let message: ServerJsonRpcMessage = - serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?; - Ok(StreamableHttpPostResponse::Json(message, session_id)) - } _ => { let body = collect_body(&mut body_stream).await?; let content_type = content_type.unwrap_or_else(|| "missing-content-type".into()); @@ -501,3 +505,85 @@ fn sse_stream_from_body( })) .boxed() } + +#[cfg(test)] +mod tests { + use axum::Json; + use axum::Router; + use axum::http::StatusCode; + use axum::response::IntoResponse; + use axum::routing::post; + use codex_exec_server::Environment; + use pretty_assertions::assert_eq; + use rmcp::model::ClientRequest; + use rmcp::model::ErrorData; + use rmcp::model::JsonRpcError; + use rmcp::model::PingRequest; + use rmcp::model::RequestId; + use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; + use serde_json::json; + use tokio::net::TcpListener; + + use super::*; + + #[tokio::test] + async fn post_message_parses_json_error_body_before_retryable_status() -> anyhow::Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let address = listener.local_addr()?; + let app = Router::new().route("/", 
post(json_error_response)); + let server = tokio::spawn(async move { axum::serve(listener, app).await }); + + let adapter = StreamableHttpClientAdapter::new( + Environment::default_for_tests().get_http_client(), + HeaderMap::new(), + /*auth_provider*/ None, + ); + let request = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let response = adapter + .post_message( + Arc::from(format!("http://{address}/")), + request, + /*session_id*/ None, + /*auth_token*/ None, + HashMap::new(), + ) + .await?; + + server.abort(); + + let StreamableHttpPostResponse::Json(message, _session_id) = response else { + panic!("expected JSON response"); + }; + let ServerJsonRpcMessage::Error(error) = message else { + panic!("expected JSON-RPC error"); + }; + assert_eq!( + error, + JsonRpcError::new( + /*id*/ Some(RequestId::Number(1)), + ErrorData::internal_error("transient json error", /*data*/ None), + ) + ); + + Ok(()) + } + + async fn json_error_response() -> impl IntoResponse { + ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CONTENT_TYPE, JSON_MIME_TYPE)], + Json(json!({ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32603, + "message": "transient json error", + }, + })), + ) + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index b1e4c3dbd38..bd6fb7b76f4 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -277,6 +277,19 @@ where } } +async fn sleep_with_retry_deadline(delay: Duration, deadline: Option) -> bool { + if let Some(deadline) = deadline { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return false; + } + time::timeout(remaining, time::sleep(delay)).await.is_ok() + } else { + time::sleep(delay).await; + true + } +} + #[derive(Debug, thiserror::Error)] enum ClientOperationError { #[error(transparent)] @@ -942,6 +955,7 @@ impl RmcpClient { )> { let 
should_retry = initial_transport.is_streamable_http(); let metric_transport = initial_transport.metric_transport(); + let retry_deadline = timeout.map(|duration| Instant::now() + duration); let mut pending_transport = Some(initial_transport); let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS @@ -952,6 +966,24 @@ impl RmcpClient { for (attempt, retry_delay_ms) in retry_schedule.enumerate() { let attempt_count = attempt + 1; + let attempt_timeout = + retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); + if let Some(remaining) = attempt_timeout + && remaining.is_zero() + { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "timeout", + ); + let duration = timeout.unwrap_or(remaining); + return Err(anyhow!( + "timed out handshaking with MCP server after {duration:?}" + )); + } + let transport = match pending_transport.take() { Some(transport) => transport, None => match Self::create_pending_transport(&self.transport_recipe).await { @@ -969,7 +1001,12 @@ impl RmcpClient { }, }; - match Self::connect_pending_transport(transport, client_service.clone(), timeout).await + match Self::connect_pending_transport( + transport, + client_service.clone(), + attempt_timeout, + ) + .await { Ok(result) => { emit_initialize_metric( @@ -1000,7 +1037,19 @@ impl RmcpClient { error = %error, "streamable HTTP MCP initialize failed with a retryable error; retrying" ); - time::sleep(delay).await; + if !sleep_with_retry_deadline(delay, retry_deadline).await { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "timeout", + ); + let duration = timeout.unwrap_or(delay); + return Err(anyhow!( + "timed out handshaking with MCP server after {duration:?}" + )); + } } Err(error) => { emit_initialize_metric( @@ -1030,13 +1079,27 @@ impl RmcpClient { { let mut session_recovery_attempted = false; let mut retry_attempt = 0; + let retry_deadline = 
timeout.map(|duration| Instant::now() + duration); loop { let service = self.service().await?; + let attempt_timeout = + retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); + if let Some(remaining) = attempt_timeout + && remaining.is_zero() + { + let duration = timeout.unwrap_or(remaining); + return Err(ClientOperationError::Timeout { + label: label.to_string(), + duration, + } + .into()); + } + match Self::run_service_operation_once( Arc::clone(&service), label, - timeout, + attempt_timeout, self.elicitation_pause_state.clone(), &operation, ) @@ -1063,7 +1126,14 @@ impl RmcpClient { error = %error, "MCP service operation failed with a retryable error; retrying" ); - time::sleep(delay).await; + if !sleep_with_retry_deadline(delay, retry_deadline).await { + let duration = timeout.unwrap_or(delay); + return Err(ClientOperationError::Timeout { + label: label.to_string(), + duration, + } + .into()); + } } Err(error) => return Err(error.into()), } @@ -1301,73 +1371,5 @@ async fn create_oauth_transport_and_runtime( } #[cfg(test)] -mod tests { - use std::collections::BTreeMap; - use std::time::Duration; - - use pretty_assertions::assert_eq; - use tokio::time; - - use super::*; - - fn metric_tags_map(tags: Vec<(&'static str, String)>) -> BTreeMap<&'static str, String> { - tags.into_iter().collect() - } - - #[tokio::test] - async fn active_time_timeout_pauses_while_elicitation_is_pending() { - let pause_state = ElicitationPauseState::new(); - let pause = pause_state.enter(); - tokio::spawn(async move { - time::sleep(Duration::from_millis(75)).await; - drop(pause); - }); - - let result = - active_time_timeout(Duration::from_millis(50), pause_state.subscribe(), async { - time::sleep(Duration::from_millis(90)).await; - "done" - }) - .await; - - assert_eq!(Ok("done"), result); - } - - #[test] - fn initialize_metric_tags_record_success_after_retry() { - let tags = metric_tags_map(initialize_metric_tags( - "streamable_http", - "success", - 
/*attempts*/ 2, - /*retry_exhausted*/ false, - "none", - )); - - assert_eq!(tags["transport"], "streamable_http"); - assert_eq!(tags["outcome"], "success"); - assert_eq!(tags["retried"], "true"); - assert_eq!(tags["attempts"], "2"); - assert_eq!(tags["retry_count"], "1"); - assert_eq!(tags["retry_exhausted"], "false"); - assert_eq!(tags["failure_kind"], "none"); - } - - #[test] - fn initialize_metric_tags_record_retry_exhaustion() { - let tags = metric_tags_map(initialize_metric_tags( - "streamable_http", - "error", - /*attempts*/ 3, - /*retry_exhausted*/ true, - "retry_exhausted", - )); - - assert_eq!(tags["transport"], "streamable_http"); - assert_eq!(tags["outcome"], "error"); - assert_eq!(tags["retried"], "true"); - assert_eq!(tags["attempts"], "3"); - assert_eq!(tags["retry_count"], "2"); - assert_eq!(tags["retry_exhausted"], "true"); - assert_eq!(tags["failure_kind"], "retry_exhausted"); - } -} +#[path = "rmcp_client_tests.rs"] +mod tests; diff --git a/codex-rs/rmcp-client/src/rmcp_client_tests.rs b/codex-rs/rmcp-client/src/rmcp_client_tests.rs new file mode 100644 index 00000000000..df13e73c798 --- /dev/null +++ b/codex-rs/rmcp-client/src/rmcp_client_tests.rs @@ -0,0 +1,67 @@ +use std::collections::BTreeMap; +use std::time::Duration; + +use pretty_assertions::assert_eq; +use tokio::time; + +use super::*; + +fn metric_tags_map(tags: Vec<(&'static str, String)>) -> BTreeMap<&'static str, String> { + tags.into_iter().collect() +} + +#[tokio::test] +async fn active_time_timeout_pauses_while_elicitation_is_pending() { + let pause_state = ElicitationPauseState::new(); + let pause = pause_state.enter(); + tokio::spawn(async move { + time::sleep(Duration::from_millis(75)).await; + drop(pause); + }); + + let result = active_time_timeout(Duration::from_millis(50), pause_state.subscribe(), async { + time::sleep(Duration::from_millis(90)).await; + "done" + }) + .await; + + assert_eq!(Ok("done"), result); +} + +#[test] +fn 
initialize_metric_tags_record_success_after_retry() { + let tags = metric_tags_map(initialize_metric_tags( + "streamable_http", + "success", + /*attempts*/ 2, + /*retry_exhausted*/ false, + "none", + )); + + assert_eq!(tags["transport"], "streamable_http"); + assert_eq!(tags["outcome"], "success"); + assert_eq!(tags["retried"], "true"); + assert_eq!(tags["attempts"], "2"); + assert_eq!(tags["retry_count"], "1"); + assert_eq!(tags["retry_exhausted"], "false"); + assert_eq!(tags["failure_kind"], "none"); +} + +#[test] +fn initialize_metric_tags_record_retry_exhaustion() { + let tags = metric_tags_map(initialize_metric_tags( + "streamable_http", + "error", + /*attempts*/ 3, + /*retry_exhausted*/ true, + "retry_exhausted", + )); + + assert_eq!(tags["transport"], "streamable_http"); + assert_eq!(tags["outcome"], "error"); + assert_eq!(tags["retried"], "true"); + assert_eq!(tags["attempts"], "3"); + assert_eq!(tags["retry_count"], "2"); + assert_eq!(tags["retry_exhausted"], "true"); + assert_eq!(tags["failure_kind"], "retry_exhausted"); +} diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index ceb10508797..265efa671db 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -4,17 +4,23 @@ use std::sync::Arc; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::time::Duration; +use std::time::Instant; +use codex_config::types::OAuthCredentialsStoreMode; use codex_exec_server::Environment; use codex_exec_server::ExecServerError; use codex_exec_server::HttpClient; use codex_exec_server::HttpRequestParams; use codex_exec_server::HttpRequestResponse; use codex_exec_server::HttpResponseBodyStream; +use codex_rmcp_client::ElicitationAction; +use codex_rmcp_client::ElicitationResponse; +use codex_rmcp_client::RmcpClient; use futures::FutureExt as _; use futures::future::BoxFuture; use 
pretty_assertions::assert_eq; use serde_json::Value; +use serde_json::json; use streamable_http_test_support::arm_initialize_failure; use streamable_http_test_support::arm_session_post_failure; @@ -22,6 +28,7 @@ use streamable_http_test_support::call_echo_tool; use streamable_http_test_support::create_client; use streamable_http_test_support::create_client_with_http_client; use streamable_http_test_support::expected_echo_result; +use streamable_http_test_support::init_params; use streamable_http_test_support::spawn_streamable_http_server; const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; @@ -134,6 +141,57 @@ async fn streamable_http_initialize_retries_retryable_status() -> anyhow::Result Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_initialize_retry_sleep_respects_startup_timeout() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + arm_initialize_failure(&base_url, /*status*/ 503, /*remaining*/ 1).await?; + + let client = RmcpClient::new_streamable_http_client( + "test-streamable-http", + &format!("{base_url}/mcp"), + Some("test-bearer".to_string()), + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + Environment::default_for_tests().get_http_client(), + /*auth_provider*/ None, + ) + .await?; + + let started = Instant::now(); + let error = client + .initialize( + init_params(), + Some(Duration::from_millis(100)), + Box::new(|_, _| { + async { + Ok(ElicitationResponse { + action: ElicitationAction::Accept, + content: Some(json!({})), + meta: None, + }) + } + .boxed() + }), + ) + .await + .unwrap_err(); + + let elapsed = started.elapsed(); + assert!( + elapsed < Duration::from_millis(500), + "initialize retry exceeded startup timeout budget: {elapsed:?}" + ); + assert!( + error + .to_string() + .contains("timed out handshaking with MCP server"), + "expected handshake timeout, got: {error:#}" + ); + + Ok(()) +} + #[tokio::test(flavor = 
"multi_thread", worker_threads = 1)] async fn streamable_http_initialize_retries_http_request_error() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; @@ -177,6 +235,9 @@ async fn streamable_http_initialize_retries_remote_http_request_error() -> anyho async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; let client = create_client(&base_url).await?; + let expected_tools = client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; arm_session_post_failure( &base_url, @@ -190,8 +251,39 @@ async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result .list_tools(/*params*/ None, Some(Duration::from_secs(5))) .await?; - assert_eq!(tools.tools.len(), 1); - assert_eq!(tools.tools[0].name, "echo"); + assert_eq!(tools, expected_tools); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_tools_list_retry_sleep_respects_operation_timeout() -> anyhow::Result<()> { + let (_server, base_url) = spawn_streamable_http_server().await?; + let client = create_client(&base_url).await?; + + arm_session_post_failure( + &base_url, + /*status*/ 503, + /*remaining*/ 1, + /*www_authenticate_headers*/ &[], + ) + .await?; + + let started = Instant::now(); + let error = client + .list_tools(/*params*/ None, Some(Duration::from_millis(100))) + .await + .unwrap_err(); + + let elapsed = started.elapsed(); + assert!( + elapsed < Duration::from_millis(500), + "tools/list retry exceeded operation timeout budget: {elapsed:?}" + ); + assert!( + error.to_string().contains("timed out awaiting tools/list"), + "expected tools/list timeout, got: {error:#}" + ); Ok(()) } @@ -199,6 +291,10 @@ async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn 
streamable_http_tools_list_retries_http_request_error() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; + let baseline_client = create_client(&base_url).await?; + let expected_tools = baseline_client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; let http_client = FailFirstMethodHttpClient::new( Environment::default_for_tests().get_http_client(), "tools/list", @@ -211,8 +307,7 @@ async fn streamable_http_tools_list_retries_http_request_error() -> anyhow::Resu .await?; assert_eq!(http_client.matching_post_attempts(), 2); - assert_eq!(tools.tools.len(), 1); - assert_eq!(tools.tools[0].name, "echo"); + assert_eq!(tools, expected_tools); Ok(()) } @@ -220,6 +315,10 @@ async fn streamable_http_tools_list_retries_http_request_error() -> anyhow::Resu #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn streamable_http_tools_list_retries_remote_http_request_error() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; + let baseline_client = create_client(&base_url).await?; + let expected_tools = baseline_client + .list_tools(/*params*/ None, Some(Duration::from_secs(5))) + .await?; let http_client = FailFirstMethodHttpClient::new( Environment::default_for_tests().get_http_client(), "tools/list", @@ -232,8 +331,7 @@ async fn streamable_http_tools_list_retries_remote_http_request_error() -> anyho .await?; assert_eq!(http_client.matching_post_attempts(), 2); - assert_eq!(tools.tools.len(), 1); - assert_eq!(tools.tools[0].name, "echo"); + assert_eq!(tools, expected_tools); Ok(()) } diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs index 324bf06b2a2..cdb2286a1d0 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -51,7 +51,7 @@ fn streamable_http_server_bin() -> Result { 
codex_utils_cargo_bin::cargo_bin("test_streamable_http_server") } -fn init_params() -> InitializeRequestParams { +pub(crate) fn init_params() -> InitializeRequestParams { let mut capabilities = ClientCapabilities::default(); capabilities.elicitation = Some(ElicitationCapability { form: Some(FormElicitationCapability { From a7ddbf9a88fb4e4442a96a37849bd60846afc5ec Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 18:08:17 -0700 Subject: [PATCH 07/15] codex: address follow-up review feedback (#25147) --- codex-rs/rmcp-client/src/rmcp_client_tests.rs | 38 +- justfile | 1 - sdk/python/src/openai_codex/api.py | 4 + .../generated/notification_registry.py | 2 + .../src/openai_codex/generated/v2_all.py | 486 +++++++++++++++--- .../test_artifact_workflow_and_binaries.py | 13 +- sdk/python/tests/test_contract_generation.py | 2 +- 7 files changed, 447 insertions(+), 99 deletions(-) diff --git a/codex-rs/rmcp-client/src/rmcp_client_tests.rs b/codex-rs/rmcp-client/src/rmcp_client_tests.rs index df13e73c798..21f0cc62dd2 100644 --- a/codex-rs/rmcp-client/src/rmcp_client_tests.rs +++ b/codex-rs/rmcp-client/src/rmcp_client_tests.rs @@ -38,13 +38,18 @@ fn initialize_metric_tags_record_success_after_retry() { "none", )); - assert_eq!(tags["transport"], "streamable_http"); - assert_eq!(tags["outcome"], "success"); - assert_eq!(tags["retried"], "true"); - assert_eq!(tags["attempts"], "2"); - assert_eq!(tags["retry_count"], "1"); - assert_eq!(tags["retry_exhausted"], "false"); - assert_eq!(tags["failure_kind"], "none"); + assert_eq!( + tags, + metric_tags_map(vec![ + ("transport", "streamable_http".to_string()), + ("outcome", "success".to_string()), + ("retried", "true".to_string()), + ("attempts", "2".to_string()), + ("retry_count", "1".to_string()), + ("retry_exhausted", "false".to_string()), + ("failure_kind", "none".to_string()), + ]) + ); } #[test] @@ -57,11 +62,16 @@ fn initialize_metric_tags_record_retry_exhaustion() { "retry_exhausted", )); - 
assert_eq!(tags["transport"], "streamable_http"); - assert_eq!(tags["outcome"], "error"); - assert_eq!(tags["retried"], "true"); - assert_eq!(tags["attempts"], "3"); - assert_eq!(tags["retry_count"], "2"); - assert_eq!(tags["retry_exhausted"], "true"); - assert_eq!(tags["failure_kind"], "retry_exhausted"); + assert_eq!( + tags, + metric_tags_map(vec![ + ("transport", "streamable_http".to_string()), + ("outcome", "error".to_string()), + ("retried", "true".to_string()), + ("attempts", "3".to_string()), + ("retry_count", "2".to_string()), + ("retry_exhausted", "true".to_string()), + ("failure_kind", "retry_exhausted".to_string()), + ]) + ); } diff --git a/justfile b/justfile index a088978ac8b..ec5df98405a 100644 --- a/justfile +++ b/justfile @@ -36,7 +36,6 @@ fmt: uv run --frozen --project ../sdk/python --extra dev ruff check --fix --fix-only ../sdk/python uv run --frozen --project ../sdk/python --extra dev ruff format ../sdk/python -# Check formatting without modifying files. fmt-check: cargo fmt -- --check --config imports_granularity=Item 2>/dev/null uv run --frozen --project ../sdk/python --extra dev ruff check --diff ../sdk/python diff --git a/sdk/python/src/openai_codex/api.py b/sdk/python/src/openai_codex/api.py index 6fc9a8243d6..96420433563 100644 --- a/sdk/python/src/openai_codex/api.py +++ b/sdk/python/src/openai_codex/api.py @@ -576,6 +576,7 @@ def turn( input: RunInput, *, approval_mode: ApprovalMode | None = None, + client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -593,6 +594,7 @@ def turn( input=wire_input, approval_policy=approval_policy, approvals_reviewer=approvals_reviewer, + client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, @@ -664,6 +666,7 @@ async def turn( input: RunInput, *, approval_mode: ApprovalMode | None = None, + client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, 
model: str | None = None, @@ -682,6 +685,7 @@ async def turn( input=wire_input, approval_policy=approval_policy, approvals_reviewer=approvals_reviewer, + client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, diff --git a/sdk/python/src/openai_codex/generated/notification_registry.py b/sdk/python/src/openai_codex/generated/notification_registry.py index c55eb5b9b75..d5e620a7a3c 100644 --- a/sdk/python/src/openai_codex/generated/notification_registry.py +++ b/sdk/python/src/openai_codex/generated/notification_registry.py @@ -57,6 +57,7 @@ from .v2_all import ThreadRealtimeStartedNotification from .v2_all import ThreadRealtimeTranscriptDeltaNotification from .v2_all import ThreadRealtimeTranscriptDoneNotification +from .v2_all import ThreadSettingsUpdatedNotification from .v2_all import ThreadStartedNotification from .v2_all import ThreadStatusChangedNotification from .v2_all import ThreadTokenUsageUpdatedNotification @@ -122,6 +123,7 @@ "thread/realtime/started": ThreadRealtimeStartedNotification, "thread/realtime/transcript/delta": ThreadRealtimeTranscriptDeltaNotification, "thread/realtime/transcript/done": ThreadRealtimeTranscriptDoneNotification, + "thread/settings/updated": ThreadSettingsUpdatedNotification, "thread/started": ThreadStartedNotification, "thread/status/changed": ThreadStatusChangedNotification, "thread/tokenUsage/updated": ThreadTokenUsageUpdatedNotification, diff --git a/sdk/python/src/openai_codex/generated/v2_all.py b/sdk/python/src/openai_codex/generated/v2_all.py index 85120b82932..15ede1801cf 100644 --- a/sdk/python/src/openai_codex/generated/v2_all.py +++ b/sdk/python/src/openai_codex/generated/v2_all.py @@ -56,7 +56,7 @@ class ActivePermissionProfile(BaseModel): extends: Annotated[ str | None, Field( - description="Parent profile identifier once permissions profiles support inheritance. This is currently always `null`." 
+ description="Parent profile identifier from the selected permissions profile's `extends` setting, when present." ), ] = None id: Annotated[ @@ -77,6 +77,11 @@ class AddCreditsNudgeEmailStatus(Enum): cooldown_active = "cooldown_active" +class AdditionalContextKind(Enum): + untrusted = "untrusted" + application = "application" + + class AdditionalNetworkPermissions(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -251,6 +256,11 @@ class AuthMode(Enum): agent_identity = "agentIdentity" +class AutoCompactTokenLimitScope(Enum): + total = "total" + body_after_prefix = "body_after_prefix" + + class AutoReviewDecisionSource(RootModel[Literal["agent"]]): model_config = ConfigDict( populate_by_name=True, @@ -563,6 +573,13 @@ class CommandMigration(BaseModel): name: str +class ComputerUseRequirements(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + allow_locked_computer_use: Annotated[bool | None, Field(alias="allowLockedComputerUse")] = None + + class MdmConfigLayerSource(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -585,6 +602,22 @@ class SystemConfigLayerSource(BaseModel): type: Annotated[Literal["system"], Field(title="SystemConfigLayerSourceType")] +class EnterpriseManagedConfigLayerSource(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: Annotated[str, Field(description="Stable identifier for the delivered layer.")] + name: Annotated[ + str, + Field( + description="Admin-facing name for the delivered layer. This is surfaced in diagnostics so users know which cloud layer needs administrator attention." 
+ ), + ] + type: Annotated[ + Literal["enterpriseManaged"], Field(title="EnterpriseManagedConfigLayerSourceType") + ] + + class UserConfigLayerSource(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -644,6 +677,7 @@ class ConfigLayerSource( RootModel[ MdmConfigLayerSource | SystemConfigLayerSource + | EnterpriseManagedConfigLayerSource | UserConfigLayerSource | ProjectConfigLayerSource | SessionFlagsConfigLayerSource @@ -657,6 +691,7 @@ class ConfigLayerSource( root: ( MdmConfigLayerSource | SystemConfigLayerSource + | EnterpriseManagedConfigLayerSource | UserConfigLayerSource | ProjectConfigLayerSource | SessionFlagsConfigLayerSource @@ -675,7 +710,7 @@ class ConfigReadParams(BaseModel): description="Optional working directory to resolve project config layers. If specified, return the effective config as seen from that directory (i.e., including any project layers between `cwd` and the project/repo root)." ), ] = None - include_layers: Annotated[bool | None, Field(alias="includeLayers")] = False + include_layers: Annotated[bool | None, Field(alias="includeLayers")] = None class CommandConfiguredHookHandler(BaseModel): @@ -913,7 +948,7 @@ class FeedbackUploadParams(BaseModel): ) classification: str extra_log_files: Annotated[list[str] | None, Field(alias="extraLogFiles")] = None - include_logs: Annotated[bool, Field(alias="includeLogs")] + include_logs: Annotated[bool | None, Field(alias="includeLogs")] = None reason: str | None = None tags: dict[str, Any] | None = None thread_id: Annotated[str | None, Field(alias="threadId")] = None @@ -939,7 +974,7 @@ class FileChangeOutputDeltaNotification(BaseModel): class FileSystemAccessMode(Enum): read = "read" write = "write" - none = "none" + deny = "deny" class PathFileSystemPath(BaseModel): @@ -1276,6 +1311,17 @@ class InputTextFunctionCallOutputContentItem(BaseModel): ] +class EncryptedContentFunctionCallOutputContentItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + 
encrypted_content: str + type: Annotated[ + Literal["encrypted_content"], + Field(title="EncryptedContentFunctionCallOutputContentItemType"), + ] + + class FuzzyFileSearchMatchType(Enum): file = "file" directory = "directory" @@ -1335,7 +1381,7 @@ class GetAccountParams(BaseModel): alias="refreshToken", description="When `true`, requests a proactive token refresh before returning.\n\nIn managed auth mode this triggers the normal refresh-token flow. In external auth mode this flag is ignored. Clients should refresh tokens themselves and call `account/login/start` with `chatgptAuthTokens`.", ), - ] = False + ] = None class GitInfo(BaseModel): @@ -1425,6 +1471,8 @@ class HookEventName(Enum): post_compact = "postCompact" session_start = "sessionStart" user_prompt_submit = "userPromptSubmit" + subagent_start = "subagentStart" + subagent_stop = "subagentStop" stop = "stop" @@ -1483,6 +1531,7 @@ class HookSource(Enum): session_flags = "sessionFlags" plugin = "plugin" cloud_requirements = "cloudRequirements" + cloud_managed_config = "cloudManagedConfig" legacy_managed_config_file = "legacyManagedConfigFile" legacy_managed_config_mdm = "legacyManagedConfigMdm" unknown = "unknown" @@ -1506,6 +1555,8 @@ class HooksListParams(BaseModel): class ImageDetail(Enum): + auto = "auto" + low = "low" high = "high" original = "original" @@ -1742,6 +1793,8 @@ class ManagedHooksRequirements(BaseModel): pre_tool_use: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="PreToolUse")] session_start: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="SessionStart")] stop: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="Stop")] + subagent_start: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="SubagentStart")] + subagent_stop: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="SubagentStop")] user_prompt_submit: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="UserPromptSubmit")] managed_dir: Annotated[str | None, Field(alias="managedDir")] = None 
windows_managed_dir: Annotated[str | None, Field(alias="windowsManagedDir")] = None @@ -1835,6 +1888,18 @@ class McpResourceReadParams(BaseModel): uri: str +class McpServerInfo(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: str | None = None + icons: list | None = None + name: str + title: str | None = None + version: str + website_url: Annotated[str | None, Field(alias="websiteUrl")] = None + + class McpServerMigration(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -2142,7 +2207,7 @@ class NetworkRequirements(BaseModel): class NetworkUnixSocketPermission(Enum): allow = "allow" - none = "none" + deny = "deny" class NonSteerableTurnKind(Enum): @@ -2188,6 +2253,32 @@ class PatchChangeKind( root: AddPatchChangeKind | DeletePatchChangeKind | UpdatePatchChangeKind +class PermissionProfileListParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cursor: Annotated[ + str | None, Field(description="Opaque pagination cursor returned by a previous call.") + ] = None + cwd: Annotated[ + str | None, + Field(description="Optional working directory to resolve project config layers."), + ] = None + limit: Annotated[ + int | None, Field(description="Optional page size; defaults to the full result set.", ge=0) + ] = None + + +class PermissionProfileSummary(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: Annotated[ + str | None, Field(description="Optional user-facing description for display in clients.") + ] = None + id: Annotated[str, Field(description="Available permission profile identifier.")] + + class Personality(Enum): none = "none" friendly = "friendly" @@ -2332,6 +2423,7 @@ class PluginInterface(BaseModel): class PluginListMarketplaceKind(Enum): local = "local" + vertical = "vertical" workspace_directory = "workspace-directory" shared_with_me = "shared-with-me" @@ -3469,6 +3561,20 @@ class SkillsConfigWriteResponse(BaseModel): effective_enabled: 
Annotated[bool, Field(alias="effectiveEnabled")] +class SkillsExtraRootsSetParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + extra_roots: Annotated[list[AbsolutePathBuf], Field(alias="extraRoots")] + + +class SkillsExtraRootsSetResponse(BaseModel): + pass + model_config = ConfigDict( + populate_by_name=True, + ) + + class SkillsListParams(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -3491,6 +3597,16 @@ class SortDirection(Enum): desc = "desc" +class SpendControlLimitSnapshot(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + limit: str + remaining_percent: Annotated[int, Field(alias="remainingPercent")] + resets_at: Annotated[int, Field(alias="resetsAt")] + used: str + + class SubAgentSourceValue(Enum): review = "review" compact = "compact" @@ -3624,6 +3740,20 @@ class ThreadCompactStartResponse(BaseModel): ) +class ThreadGoalClearParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + +class ThreadGoalClearResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + cleared: bool + + class ThreadGoalClearedNotification(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -3631,6 +3761,13 @@ class ThreadGoalClearedNotification(BaseModel): thread_id: Annotated[str, Field(alias="threadId")] +class ThreadGoalGetParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + + class ThreadGoalStatus(Enum): active = "active" paused = "paused" @@ -3749,6 +3886,7 @@ class McpToolCallThreadItem(BaseModel): error: McpToolCallError | None = None id: str mcp_app_resource_uri: Annotated[str | None, Field(alias="mcpAppResourceUri")] = None + plugin_id: Annotated[str | None, Field(alias="pluginId")] = None result: McpToolCallResult | None = None server: str status: McpToolCallStatus @@ -3922,7 +4060,7 @@ class 
ThreadReadParams(BaseModel): alias="includeTurns", description="When true, include turns and their items from rollout history.", ), - ] = False + ] = None thread_id: Annotated[str, Field(alias="threadId")] @@ -4536,10 +4674,19 @@ class AccountUpdatedNotification(BaseModel): plan_type: Annotated[PlanType | None, Field(alias="planType")] = None +class AdditionalContextEntry(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind: AdditionalContextKind + value: str + + class AppConfig(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + approvals_reviewer: ApprovalsReviewer | None = None default_tools_approval_mode: AppToolApproval | None = None default_tools_enabled: bool | None = None destructive_enabled: bool | None = None @@ -4629,6 +4776,24 @@ class ThreadNameSetRequest(BaseModel): params: ThreadSetNameParams +class ThreadGoalGetRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/goal/get"], Field(title="Thread/goal/getRequestMethod")] + params: ThreadGoalGetParams + + +class ThreadGoalClearRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/goal/clear"], Field(title="Thread/goal/clearRequestMethod")] + params: ThreadGoalClearParams + + class ThreadMetadataUpdateRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -4730,6 +4895,17 @@ class SkillsListRequest(BaseModel): params: SkillsListParams +class SkillsExtraRootsSetRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["skills/extraRoots/set"], Field(title="Skills/extraRoots/setRequestMethod") + ] + params: SkillsExtraRootsSetParams + + class HooksListRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -4995,6 +5171,17 @@ class ExperimentalFeatureListRequest(BaseModel): params: ExperimentalFeatureListParams +class 
PermissionProfileListRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[ + Literal["permissionProfile/list"], Field(title="PermissionProfile/listRequestMethod") + ] + params: PermissionProfileListParams + + class ExperimentalFeatureEnablementSetRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -5431,16 +5618,22 @@ class ConfigRequirements(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + allow_appshots: Annotated[bool | None, Field(alias="allowAppshots")] = None allow_managed_hooks_only: Annotated[bool | None, Field(alias="allowManagedHooksOnly")] = None allowed_approval_policies: Annotated[ list[AskForApproval] | None, Field(alias="allowedApprovalPolicies") ] = None + allowed_permissions: Annotated[list[str] | None, Field(alias="allowedPermissions")] = None allowed_sandbox_modes: Annotated[ list[SandboxMode] | None, Field(alias="allowedSandboxModes") ] = None allowed_web_search_modes: Annotated[ list[WebSearchMode] | None, Field(alias="allowedWebSearchModes") ] = None + allowed_windows_sandbox_implementations: Annotated[ + list[WindowsSandboxSetupMode] | None, Field(alias="allowedWindowsSandboxImplementations") + ] = None + computer_use: Annotated[ComputerUseRequirements | None, Field(alias="computerUse")] = None enforce_residency: Annotated[ResidencyRequirement | None, Field(alias="enforceResidency")] = ( None ) @@ -5608,13 +5801,19 @@ class InputImageFunctionCallOutputContentItem(BaseModel): class FunctionCallOutputContentItem( - RootModel[InputTextFunctionCallOutputContentItem | InputImageFunctionCallOutputContentItem] + RootModel[ + InputTextFunctionCallOutputContentItem + | InputImageFunctionCallOutputContentItem + | EncryptedContentFunctionCallOutputContentItem + ] ): model_config = ConfigDict( populate_by_name=True, ) root: Annotated[ - InputTextFunctionCallOutputContentItem | InputImageFunctionCallOutputContentItem, + InputTextFunctionCallOutputContentItem 
+ | InputImageFunctionCallOutputContentItem + | EncryptedContentFunctionCallOutputContentItem, Field( description="Responses API compatible content items that can be returned by a tool call. This is a subset of ContentItem with the types we support as function call outputs." ), @@ -5767,6 +5966,7 @@ class ListMcpServerStatusParams(BaseModel): int | None, Field(description="Optional page size; defaults to a server-defined value.", ge=0), ] = None + thread_id: Annotated[str | None, Field(alias="threadId")] = None class McpResourceReadResponse(BaseModel): @@ -5784,6 +5984,7 @@ class McpServerStatus(BaseModel): name: str resource_templates: Annotated[list[ResourceTemplate], Field(alias="resourceTemplates")] resources: list[Resource] + server_info: Annotated[McpServerInfo | None, Field(alias="serverInfo")] = None tools: dict[str, Tool] @@ -5817,6 +6018,13 @@ class Model(BaseModel): ] = [] availability_nux: Annotated[ModelAvailabilityNux | None, Field(alias="availabilityNux")] = None default_reasoning_effort: Annotated[ReasoningEffort, Field(alias="defaultReasoningEffort")] + default_service_tier: Annotated[ + str | None, + Field( + alias="defaultServiceTier", + description="Catalog default service tier id for this model, when one is configured.", + ), + ] = None description: str display_name: Annotated[str, Field(alias="displayName")] hidden: bool @@ -5859,6 +6067,20 @@ class OverriddenMetadata(BaseModel): overriding_layer: Annotated[ConfigLayerMetadata, Field(alias="overridingLayer")] +class PermissionProfileListResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + data: list[PermissionProfileSummary] + next_cursor: Annotated[ + str | None, + Field( + alias="nextCursor", + description="Opaque cursor to pass to the next call to continue after the last item. 
If None, there are no more items to return.", + ), + ] = None + + class PluginSharePrincipal(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -5926,6 +6148,9 @@ class RateLimitSnapshot(BaseModel): populate_by_name=True, ) credits: CreditsSnapshot | None = None + individual_limit: Annotated[ + SpendControlLimitSnapshot | None, Field(alias="individualLimit") + ] = None limit_id: Annotated[str | None, Field(alias="limitId")] = None limit_name: Annotated[str | None, Field(alias="limitName")] = None plan_type: Annotated[PlanType | None, Field(alias="planType")] = None @@ -6332,6 +6557,30 @@ class ThreadGoal(BaseModel): updated_at: Annotated[int, Field(alias="updatedAt")] +class ThreadGoalGetResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + goal: ThreadGoal | None = None + + +class ThreadGoalSetParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + objective: str | None = None + status: ThreadGoalStatus | None = None + thread_id: Annotated[str, Field(alias="threadId")] + token_budget: Annotated[int | None, Field(alias="tokenBudget")] = None + + +class ThreadGoalSetResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + goal: ThreadGoal + + class ThreadGoalUpdatedNotification(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6345,6 +6594,7 @@ class UserMessageThreadItem(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + client_id: Annotated[str | None, Field(alias="clientId")] = None content: list[UserInput] id: str type: Annotated[Literal["userMessage"], Field(title="UserMessageThreadItemType")] @@ -6536,6 +6786,55 @@ class ThreadListParams(BaseModel): ] = None +class ThreadResumeInitialTurnsPageParams(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + items_view: Annotated[ + TurnItemsView | None, + Field( + alias="itemsView", + description="How much item detail to include for each returned turn; defaults to summary.", + ), 
+ ] = None + limit: Annotated[int | None, Field(description="Optional turn page size.", ge=0)] = None + sort_direction: Annotated[ + SortDirection | None, + Field( + alias="sortDirection", + description="Optional turn pagination direction; defaults to descending.", + ), + ] = None + + +class ThreadSettings(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + active_permission_profile: Annotated[ + ActivePermissionProfile | None, Field(alias="activePermissionProfile") + ] = None + approval_policy: Annotated[AskForApproval, Field(alias="approvalPolicy")] + approvals_reviewer: Annotated[ApprovalsReviewer, Field(alias="approvalsReviewer")] + collaboration_mode: Annotated[CollaborationMode, Field(alias="collaborationMode")] + cwd: AbsolutePathBuf + effort: ReasoningEffort | None = None + model: str + model_provider: Annotated[str, Field(alias="modelProvider")] + personality: Personality | None = None + sandbox_policy: Annotated[SandboxPolicy, Field(alias="sandboxPolicy")] + service_tier: Annotated[str | None, Field(alias="serviceTier")] = None + summary: ReasoningSummary | None = None + + +class ThreadSettingsUpdatedNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + thread_id: Annotated[str, Field(alias="threadId")] + thread_settings: Annotated[ThreadSettings, Field(alias="threadSettings")] + + class ThreadStartParams(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6648,6 +6947,7 @@ class TurnStartParams(BaseModel): description="Override where approval requests are routed for review on this turn and subsequent turns.", ), ] = None + client_user_message_id: Annotated[str | None, Field(alias="clientUserMessageId")] = None cwd: Annotated[ str | None, Field(description="Override the working directory for this turn and subsequent turns."), @@ -6696,6 +6996,7 @@ class TurnSteerParams(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + client_user_message_id: Annotated[str | None, 
Field(alias="clientUserMessageId")] = None expected_turn_id: Annotated[ str, Field( @@ -6803,6 +7104,15 @@ class ThreadForkRequest(BaseModel): params: ThreadForkParams +class ThreadGoalSetRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: RequestId + method: Annotated[Literal["thread/goal/set"], Field(title="Thread/goal/setRequestMethod")] + params: ThreadGoalSetParams + + class ThreadListRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6891,6 +7201,41 @@ class ConfigValueWriteRequest(BaseModel): params: ConfigValueWriteParams +class Config(BaseModel): + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + analytics: AnalyticsConfig | None = None + approval_policy: AskForApproval | None = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + description="[UNSTABLE] Optional default for where approval requests are routed for review." + ), + ] = None + compact_prompt: str | None = None + desktop: dict[str, Any] | None = None + developer_instructions: str | None = None + forced_chatgpt_workspace_id: ForcedChatgptWorkspaceIds | None = None + forced_login_method: ForcedLoginMethod | None = None + instructions: str | None = None + model: str | None = None + model_auto_compact_token_limit: int | None = None + model_auto_compact_token_limit_scope: AutoCompactTokenLimitScope | None = None + model_context_window: int | None = None + model_provider: str | None = None + model_reasoning_effort: ReasoningEffort | None = None + model_reasoning_summary: ReasoningSummary | None = None + model_verbosity: Verbosity | None = None + review_model: str | None = None + sandbox_mode: SandboxMode | None = None + sandbox_workspace_write: SandboxWorkspaceWrite | None = None + service_tier: str | None = None + tools: ToolsV2 | None = None + web_search: WebSearchMode | None = None + + class ConfigBatchWriteParams(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6913,6 +7258,15 
@@ class ConfigBatchWriteParams(BaseModel): ] = None +class ConfigReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + config: Config + layers: list[ConfigLayer] | None = None + origins: dict[str, ConfigLayerMetadata] + + class ConfigWriteResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7115,29 +7469,6 @@ class PluginSummary(BaseModel): source: PluginSource -class ProfileV2(BaseModel): - model_config = ConfigDict( - extra="allow", - populate_by_name=True, - ) - approval_policy: AskForApproval | None = None - approvals_reviewer: Annotated[ - ApprovalsReviewer | None, - Field( - description="[UNSTABLE] Optional profile-level override for where approval requests are routed for review. If omitted, the enclosing config default is used." - ), - ] = None - chatgpt_base_url: str | None = None - model: str | None = None - model_provider: str | None = None - model_reasoning_effort: ReasoningEffort | None = None - model_reasoning_summary: ReasoningSummary | None = None - model_verbosity: Verbosity | None = None - service_tier: str | None = None - tools: ToolsV2 | None = None - web_search: WebSearchMode | None = None - - class RequestPermissionProfile(BaseModel): model_config = ConfigDict( extra="forbid", @@ -7229,6 +7560,16 @@ class ThreadGoalUpdatedServerNotification(BaseModel): params: ThreadGoalUpdatedNotification +class ThreadSettingsUpdatedServerNotification(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + method: Annotated[ + Literal["thread/settings/updated"], Field(title="Thread/settings/updatedNotificationMethod") + ] + params: ThreadSettingsUpdatedNotification + + class ThreadTokenUsageUpdatedServerNotification(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7393,6 +7734,15 @@ class TurnStartedNotification(BaseModel): turn: Turn +class TurnsPage(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + backwards_cursor: Annotated[str | None, 
Field(alias="backwardsCursor")] = None + data: list[Turn] + next_cursor: Annotated[str | None, Field(alias="nextCursor")] = None + + class PluginShareSaveRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7411,51 +7761,6 @@ class ConfigBatchWriteRequest(BaseModel): params: ConfigBatchWriteParams -class Config(BaseModel): - model_config = ConfigDict( - extra="allow", - populate_by_name=True, - ) - analytics: AnalyticsConfig | None = None - approval_policy: AskForApproval | None = None - approvals_reviewer: Annotated[ - ApprovalsReviewer | None, - Field( - description="[UNSTABLE] Optional default for where approval requests are routed for review." - ), - ] = None - compact_prompt: str | None = None - desktop: dict[str, Any] | None = None - developer_instructions: str | None = None - forced_chatgpt_workspace_id: ForcedChatgptWorkspaceIds | None = None - forced_login_method: ForcedLoginMethod | None = None - instructions: str | None = None - model: str | None = None - model_auto_compact_token_limit: int | None = None - model_context_window: int | None = None - model_provider: str | None = None - model_reasoning_effort: ReasoningEffort | None = None - model_reasoning_summary: ReasoningSummary | None = None - model_verbosity: Verbosity | None = None - profile: str | None = None - profiles: dict[str, ProfileV2] | None = {} - review_model: str | None = None - sandbox_mode: SandboxMode | None = None - sandbox_workspace_write: SandboxWorkspaceWrite | None = None - service_tier: str | None = None - tools: ToolsV2 | None = None - web_search: WebSearchMode | None = None - - -class ConfigReadResponse(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - config: Config - layers: list[ConfigLayer] | None = None - origins: dict[str, ConfigLayerMetadata] - - class ExternalAgentConfigDetectResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7739,6 +8044,13 @@ class Thread(BaseModel): ), ] name: Annotated[str | None, 
Field(description="Optional user-facing thread title.")] = None + parent_thread_id: Annotated[ + str | None, + Field( + alias="parentThreadId", + description="The ID of the parent thread. This will only be set if this thread is a subagent.", + ), + ] = None path: Annotated[str | None, Field(description="[UNSTABLE] Path to the thread on disk.")] = None preview: Annotated[ str, Field(description="Usually the first user message in the thread, if available.") @@ -7892,6 +8204,14 @@ class ThreadRollbackResponse(BaseModel): ] +class ThreadSearchResult(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + snippet: str + thread: Thread + + class ThreadStartResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7960,6 +8280,9 @@ class ClientRequest( | ThreadArchiveRequest | ThreadUnsubscribeRequest | ThreadNameSetRequest + | ThreadGoalSetRequest + | ThreadGoalGetRequest + | ThreadGoalClearRequest | ThreadMetadataUpdateRequest | ThreadUnarchiveRequest | ThreadCompactStartRequest @@ -7971,6 +8294,7 @@ class ClientRequest( | ThreadReadRequest | ThreadInjectItemsRequest | SkillsListRequest + | SkillsExtraRootsSetRequest | HooksListRequest | MarketplaceAddRequest | MarketplaceRemoveRequest @@ -8004,6 +8328,7 @@ class ClientRequest( | ModelListRequest | ModelProviderCapabilitiesReadRequest | ExperimentalFeatureListRequest + | PermissionProfileListRequest | ExperimentalFeatureEnablementSetRequest | McpServerOauthLoginRequest | ConfigMcpServerReloadRequest @@ -8043,6 +8368,9 @@ class ClientRequest( | ThreadArchiveRequest | ThreadUnsubscribeRequest | ThreadNameSetRequest + | ThreadGoalSetRequest + | ThreadGoalGetRequest + | ThreadGoalClearRequest | ThreadMetadataUpdateRequest | ThreadUnarchiveRequest | ThreadCompactStartRequest @@ -8054,6 +8382,7 @@ class ClientRequest( | ThreadReadRequest | ThreadInjectItemsRequest | SkillsListRequest + | SkillsExtraRootsSetRequest | HooksListRequest | MarketplaceAddRequest | MarketplaceRemoveRequest @@ -8087,6 
+8416,7 @@ class ClientRequest( | ModelListRequest | ModelProviderCapabilitiesReadRequest | ExperimentalFeatureListRequest + | PermissionProfileListRequest | ExperimentalFeatureEnablementSetRequest | McpServerOauthLoginRequest | ConfigMcpServerReloadRequest @@ -8158,6 +8488,7 @@ class ServerNotification( | ThreadNameUpdatedServerNotification | ThreadGoalUpdatedServerNotification | ThreadGoalClearedServerNotification + | ThreadSettingsUpdatedServerNotification | ThreadTokenUsageUpdatedServerNotification | TurnStartedServerNotification | HookStartedServerNotification @@ -8227,6 +8558,7 @@ class ServerNotification( | ThreadNameUpdatedServerNotification | ThreadGoalUpdatedServerNotification | ThreadGoalClearedServerNotification + | ThreadSettingsUpdatedServerNotification | ThreadTokenUsageUpdatedServerNotification | TurnStartedServerNotification | HookStartedServerNotification diff --git a/sdk/python/tests/test_artifact_workflow_and_binaries.py b/sdk/python/tests/test_artifact_workflow_and_binaries.py index d3f2f8d0ec1..bc57e4b5020 100644 --- a/sdk/python/tests/test_artifact_workflow_and_binaries.py +++ b/sdk/python/tests/test_artifact_workflow_and_binaries.py @@ -154,12 +154,13 @@ def test_schema_normalization_only_flattens_string_literal_oneofs( assert flattened == [ "MessagePhase", "TurnItemsView", - "PluginAvailability", "AuthMode", + "PluginAvailability", "InputModality", "ExperimentalFeatureStage", "ProcessOutputStream", "CommandExecOutputStream", + "AutoCompactTokenLimitScope", ] @@ -250,10 +251,10 @@ def test_source_sdk_template_pins_published_runtime() -> None: "dependencies": pyproject["project"]["dependencies"], } == { "sdk_template_version": "0.0.0-dev", - "runtime_pin": "0.132.0", + "runtime_pin": "0.137.0a4", "dependencies": [ "pydantic>=2.12", - "openai-codex-cli-bin==0.132.0", + "openai-codex-cli-bin==0.137.0a4", ], } @@ -328,7 +329,7 @@ def test_runtime_setup_reads_independent_runtime_pin_and_release_tags() -> None: } == { "package_name": 
"openai-codex-cli-bin", "sdk_template_version": "0.0.0-dev", - "runtime_pin": "0.132.0", + "runtime_pin": "0.137.0a4", "normalized_release_version": "0.116.0a1", "release_tag": "rust-v0.116.0-alpha.1", } @@ -543,7 +544,7 @@ def test_stage_sdk_release_preserves_reviewed_runtime_pin(tmp_path: Path) -> Non "version": "0.1.0b1", "dependencies": [ "pydantic>=2.12", - "openai-codex-cli-bin==0.132.0", + "openai-codex-cli-bin==0.137.0a4", ], } assert ( @@ -596,7 +597,7 @@ def test_sdk_beta_release_can_pin_stable_runtime(tmp_path: Path) -> None: "runtime_version": "0.132.0", "sdk_dependencies": [ "pydantic>=2.12", - "openai-codex-cli-bin==0.132.0", + "openai-codex-cli-bin==0.137.0a4", ], } diff --git a/sdk/python/tests/test_contract_generation.py b/sdk/python/tests/test_contract_generation.py index 2c95d06a3e7..36d37735c7a 100644 --- a/sdk/python/tests/test_contract_generation.py +++ b/sdk/python/tests/test_contract_generation.py @@ -40,7 +40,7 @@ def test_generated_files_are_up_to_date(): # Regenerate contract artifacts via the pinned runtime package, not a local # app-server binary from the checkout or CI environment. 
- assert importlib.metadata.version("openai-codex-cli-bin") == "0.132.0" + assert importlib.metadata.version("openai-codex-cli-bin") == "0.137.0a4" env = os.environ.copy() env.pop("CODEX_EXEC_PATH", None) python_bin = str(Path(sys.executable).parent) From a6f9512dbf2f7271ce70377859d0444f553b4954 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 18:15:31 -0700 Subject: [PATCH 08/15] codex: fix python SDK signature test (#25147) --- sdk/python/src/openai_codex/api.py | 4 ++++ sdk/python/tests/test_public_api_signatures.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/sdk/python/src/openai_codex/api.py b/sdk/python/src/openai_codex/api.py index 96420433563..560d053aaef 100644 --- a/sdk/python/src/openai_codex/api.py +++ b/sdk/python/src/openai_codex/api.py @@ -542,6 +542,7 @@ def run( input: RunInput, *, approval_mode: ApprovalMode | None = None, + client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -555,6 +556,7 @@ def run( turn = self.turn( input, approval_mode=approval_mode, + client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, @@ -632,6 +634,7 @@ async def run( input: RunInput, *, approval_mode: ApprovalMode | None = None, + client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -645,6 +648,7 @@ async def run( turn = await self.turn( input, approval_mode=approval_mode, + client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, diff --git a/sdk/python/tests/test_public_api_signatures.py b/sdk/python/tests/test_public_api_signatures.py index c26ac90f6b1..25a3b831d25 100644 --- a/sdk/python/tests/test_public_api_signatures.py +++ b/sdk/python/tests/test_public_api_signatures.py @@ -371,6 +371,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], Thread.turn: [ "approval_mode", + 
"client_user_message_id", "cwd", "effort", "model", @@ -382,6 +383,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], Thread.run: [ "approval_mode", + "client_user_message_id", "cwd", "effort", "model", @@ -446,6 +448,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], AsyncThread.turn: [ "approval_mode", + "client_user_message_id", "cwd", "effort", "model", @@ -457,6 +460,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], AsyncThread.run: [ "approval_mode", + "client_user_message_id", "cwd", "effort", "model", From c46094542376d8073516ff6e13075e82589ae283 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 18:30:15 -0700 Subject: [PATCH 09/15] codex: narrow retry PR scope (#25147) --- .../rmcp-client/src/http_client_adapter.rs | 63 +++++++- codex-rs/rmcp-client/src/rmcp_client.rs | 5 +- codex-rs/rmcp-client/src/rmcp_client_tests.rs | 33 ++++ sdk/python/scripts/update_sdk_artifacts.py | 6 +- sdk/python/src/openai_codex/api.py | 8 - .../test_artifact_workflow_and_binaries.py | 147 +++++++++++++++--- .../tests/test_public_api_signatures.py | 4 - 7 files changed, 228 insertions(+), 38 deletions(-) diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index fbb8039659f..df7bee3a83b 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -29,7 +29,6 @@ use reqwest::header::HeaderMap; use reqwest::header::HeaderName; use rmcp::model::ClientJsonRpcMessage; use rmcp::model::JsonRpcMessage; -use rmcp::model::ServerJsonRpcMessage; use rmcp::transport::streamable_http_client::AuthRequiredError; use rmcp::transport::streamable_http_client::InsufficientScopeError; use rmcp::transport::streamable_http_client::StreamableHttpClient; @@ -191,8 +190,15 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { && content_type.starts_with(JSON_MIME_TYPE) { let 
body = collect_body(&mut body_stream).await?; - let message: ServerJsonRpcMessage = - serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?; + let message = match serde_json::from_slice(&body) { + Ok(message) => message, + Err(_error) if is_retryable_http_status(response.status) => { + return Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::RetryableHttpStatus(response.status), + )); + } + Err(error) => return Err(StreamableHttpError::Deserialize(error)), + }; return Ok(StreamableHttpPostResponse::Json(message, session_id)); } @@ -520,6 +526,7 @@ mod tests { use rmcp::model::JsonRpcError; use rmcp::model::PingRequest; use rmcp::model::RequestId; + use rmcp::model::ServerJsonRpcMessage; use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; use serde_json::json; use tokio::net::TcpListener; @@ -572,6 +579,46 @@ mod tests { Ok(()) } + #[tokio::test] + async fn post_message_retries_non_json_rpc_json_error_body() -> anyhow::Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let address = listener.local_addr()?; + let app = Router::new().route("/", post(non_json_rpc_json_error_response)); + let server = tokio::spawn(async move { axum::serve(listener, app).await }); + + let adapter = StreamableHttpClientAdapter::new( + Environment::default_for_tests().get_http_client(), + HeaderMap::new(), + /*auth_provider*/ None, + ); + let request = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let result = adapter + .post_message( + Arc::from(format!("http://{address}/")), + request, + /*session_id*/ None, + /*auth_token*/ None, + HashMap::new(), + ) + .await; + + server.abort(); + + let Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::RetryableHttpStatus(status), + )) = result + else { + panic!("expected retryable HTTP status error"); + }; + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE.as_u16()); + + Ok(()) + 
} + async fn json_error_response() -> impl IntoResponse { ( StatusCode::INTERNAL_SERVER_ERROR, @@ -586,4 +633,14 @@ mod tests { })), ) } + + async fn non_json_rpc_json_error_response() -> impl IntoResponse { + ( + StatusCode::SERVICE_UNAVAILABLE, + [(CONTENT_TYPE, JSON_MIME_TYPE)], + Json(json!({ + "error": "service temporarily unavailable", + })), + ) + } } diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index bd6fb7b76f4..c909c88bdaf 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1222,7 +1222,10 @@ impl RmcpClient { fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool { match error { rmcp::service::ClientInitializeError::TransportError { error, context } - if context.as_ref() == "send initialize request" => + if matches!( + context.as_ref(), + "send initialize request" | "send initialized notification" + ) => { error .error diff --git a/codex-rs/rmcp-client/src/rmcp_client_tests.rs b/codex-rs/rmcp-client/src/rmcp_client_tests.rs index 21f0cc62dd2..f8689a73391 100644 --- a/codex-rs/rmcp-client/src/rmcp_client_tests.rs +++ b/codex-rs/rmcp-client/src/rmcp_client_tests.rs @@ -1,7 +1,9 @@ +use std::any::TypeId; use std::collections::BTreeMap; use std::time::Duration; use pretty_assertions::assert_eq; +use rmcp::transport::DynamicTransportError; use tokio::time; use super::*; @@ -75,3 +77,34 @@ fn initialize_metric_tags_record_retry_exhaustion() { ]) ); } + +#[test] +fn retryable_initialize_error_includes_initialized_notification_context() { + let contexts = [ + "send initialize request", + "send initialized notification", + "receive initialize response", + ]; + + assert_eq!( + contexts.map(|context| { + RmcpClient::is_retryable_client_initialize_error(&retryable_initialize_error(context)) + }), + [true, true, false], + ); +} + +fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientInitializeError { + 
rmcp::service::ClientInitializeError::TransportError { + error: DynamicTransportError::from_parts( + "streamable_http", + TypeId::of::<()>(), + Box::new(StreamableHttpError::Client( + StreamableHttpClientAdapterError::RetryableHttpStatus( + reqwest::StatusCode::SERVICE_UNAVAILABLE.as_u16(), + ), + )), + ), + context: context.into(), + } +} diff --git a/sdk/python/scripts/update_sdk_artifacts.py b/sdk/python/scripts/update_sdk_artifacts.py index 1742ced547e..a4af5d605e1 100755 --- a/sdk/python/scripts/update_sdk_artifacts.py +++ b/sdk/python/scripts/update_sdk_artifacts.py @@ -1163,7 +1163,9 @@ def generate_public_api_flat_methods() -> None: turn_start_fields = _load_public_fields( "openai_codex.generated.v2_all", "TurnStartParams", - exclude={"thread_id", "input", *approval_fields}, + # Keep the wire model current without exposing this app-server field + # through the ergonomic Python API yet. + exclude={"thread_id", "input", "client_user_message_id", *approval_fields}, ) turn_start_fields = _replace_public_sandbox_field(turn_start_fields, wire_name="sandbox_policy") @@ -1267,7 +1269,7 @@ def build_parser() -> argparse.ArgumentParser: "--platform-tag", help=( "Optional wheel platform tag override, for example " - "macosx_11_0_arm64 or musllinux_1_1_x86_64." + "macosx_11_0_arm64 or manylinux_2_17_x86_64." 
), ) return parser diff --git a/sdk/python/src/openai_codex/api.py b/sdk/python/src/openai_codex/api.py index 560d053aaef..6fc9a8243d6 100644 --- a/sdk/python/src/openai_codex/api.py +++ b/sdk/python/src/openai_codex/api.py @@ -542,7 +542,6 @@ def run( input: RunInput, *, approval_mode: ApprovalMode | None = None, - client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -556,7 +555,6 @@ def run( turn = self.turn( input, approval_mode=approval_mode, - client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, @@ -578,7 +576,6 @@ def turn( input: RunInput, *, approval_mode: ApprovalMode | None = None, - client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -596,7 +593,6 @@ def turn( input=wire_input, approval_policy=approval_policy, approvals_reviewer=approvals_reviewer, - client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, @@ -634,7 +630,6 @@ async def run( input: RunInput, *, approval_mode: ApprovalMode | None = None, - client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -648,7 +643,6 @@ async def run( turn = await self.turn( input, approval_mode=approval_mode, - client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, @@ -670,7 +664,6 @@ async def turn( input: RunInput, *, approval_mode: ApprovalMode | None = None, - client_user_message_id: str | None = None, cwd: str | None = None, effort: ReasoningEffort | None = None, model: str | None = None, @@ -689,7 +682,6 @@ async def turn( input=wire_input, approval_policy=approval_policy, approvals_reviewer=approvals_reviewer, - client_user_message_id=client_user_message_id, cwd=cwd, effort=effort, model=model, diff --git a/sdk/python/tests/test_artifact_workflow_and_binaries.py 
b/sdk/python/tests/test_artifact_workflow_and_binaries.py index bc57e4b5020..501f1e100a7 100644 --- a/sdk/python/tests/test_artifact_workflow_and_binaries.py +++ b/sdk/python/tests/test_artifact_workflow_and_binaries.py @@ -14,6 +14,18 @@ ROOT = Path(__file__).resolve().parents[1] +def _load_root_format_script_module(): + """Load the root formatter driver so tests exercise its real command graph.""" + script_path = ROOT.parents[1] / "scripts" / "format.py" + spec = importlib.util.spec_from_file_location("format_repo", script_path) + if spec is None or spec.loader is None: + raise AssertionError(f"Failed to load script module: {script_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + def _load_update_script_module(): """Load the maintenance script as a module so tests exercise real helpers.""" script_path = ROOT / "scripts" / "update_sdk_artifacts.py" @@ -68,40 +80,135 @@ def test_generation_has_single_maintenance_entrypoint_script() -> None: assert scripts == ["update_sdk_artifacts.py"] -def test_root_fmt_recipe_formats_rust_and_python_sdk() -> None: - """The repo fmt command should work from Rust and Python SDK directories.""" +def test_root_fmt_recipes_use_shared_formatter_driver() -> None: + """The root formatting recipes should use the shared cross-platform driver.""" justfile = ROOT.parents[1] / "justfile" lines = justfile.read_text().splitlines() fmt_index = lines.index("fmt:") + fmt_check_index = lines.index("fmt-check:") next_recipe_index = next( index - for index in range(fmt_index + 1, len(lines)) + for index in range(fmt_check_index + 1, len(lines)) if lines[index] and not lines[index].startswith((" ", "\t", "#")) ) - fmt_recipe = lines[fmt_index:next_recipe_index] actual = { "working_directory": lines[0], - "previous_attribute": lines[fmt_index - 1], - "commands": [line.strip() for line in fmt_recipe[1:] if line.strip()], + "fmt_comment": next(line for line in 
reversed(lines[:fmt_index]) if line.startswith("#")), + "fmt_commands": [ + line.strip() + for line in lines[fmt_index + 1 : fmt_check_index] + if line.strip() and not line.startswith("#") + ], + "fmt_check_comment": next( + line for line in reversed(lines[:fmt_check_index]) if line.startswith("#") + ), + "fmt_check_commands": [ + line.strip() for line in lines[fmt_check_index + 1 : next_recipe_index] if line.strip() + ], } expected = { "working_directory": 'set working-directory := "codex-rs"', - "previous_attribute": "# Format Rust and Python SDK code.", - "commands": [ - "cargo fmt -- --config imports_granularity=Item 2>/dev/null", - "uv run --frozen --project ../sdk/python --extra dev ruff check --fix --fix-only ../sdk/python", - "uv run --frozen --project ../sdk/python --extra dev ruff format ../sdk/python", - ], + "fmt_comment": "# Format the justfile, Rust, Python SDK code, and Python scripts.", + "fmt_commands": ["{{ python }} ../scripts/format.py"], + "fmt_check_comment": "# Check formatting without modifying files.", + "fmt_check_commands": ["{{ python }} ../scripts/format.py --check"], } assert actual == expected, ( - "The root `just fmt` recipe must run Rust fmt and Python SDK Ruff. " - "Fix the `fmt` recipe in `justfile`, then run `just fmt`.\n" + "The root formatting recipes must use the shared formatter driver. 
" + "Fix the recipes in `justfile`, then run `just fmt`.\n" f"Expected: {json.dumps(expected, indent=2)}\n" f"Actual: {json.dumps(actual, indent=2)}" ) +def test_root_format_driver_covers_all_formatter_groups() -> None: + """The shared driver should retain every formatter in both modes.""" + script = _load_root_format_script_module() + formatters = script.formatter_groups(check=False) + checks = script.formatter_groups(check=True) + + assert [group.name for group in formatters] == [ + "Just", + "Rust", + "Python SDK", + "Python scripts", + ] + assert [group.name for group in checks] == [group.name for group in formatters] + assert [len(group.commands) for group in formatters] == [1, 1, 2, 1] + assert [len(group.commands) for group in checks] == [ + len(group.commands) for group in formatters + ] + sdk_uv_run_args = ( + "uv", + "run", + "--frozen", + "--project", + "sdk/python", + "--no-sync", + "--with", + "ruff", + ) + scripts_uv_run_args = ( + "uv", + "run", + "--frozen", + "--project", + "scripts", + "--no-sync", + "--with", + "ruff", + ) + assert all( + command.args[: len(sdk_uv_run_args)] == sdk_uv_run_args + for group in (formatters[2], checks[2]) + for command in group.commands + ) + assert all( + command.args[: len(scripts_uv_run_args)] == scripts_uv_run_args + for group in (formatters[3], checks[3]) + for command in group.commands + ) + assert formatters[2].commands[0].args[-5:] == ( + "ruff", + "check", + "--fix", + "--fix-only", + "sdk/python", + ) + assert checks[2].commands[0].args[-4:] == ( + "ruff", + "check", + "--diff", + "sdk/python", + ) + assert formatters[0].commands[-1].args == ("just", "--unstable", "--fmt") + assert checks[0].commands[-1].args == ("just", "--unstable", "--fmt", "--check") + assert formatters[1].commands[-1].args == ( + "cargo", + "fmt", + "--", + "--config", + "imports_granularity=Item", + ) + assert checks[1].commands[-1].args == ( + "cargo", + "fmt", + "--", + "--config", + "imports_granularity=Item", + "--check", + ) + 
assert [group.commands[-1].args[-3:] for group in formatters[2:]] == [ + ("ruff", "format", "sdk/python"), + ("ruff", "format", "scripts"), + ] + assert [group.commands[-1].args[-4:] for group in checks[2:]] == [ + ("ruff", "format", "--check", "sdk/python"), + ("ruff", "format", "--check", "scripts"), + ] + + def test_generate_types_wires_all_generation_steps() -> None: """The type generation command should refresh every schema-derived artifact.""" source = (ROOT / "scripts" / "update_sdk_artifacts.py").read_text() @@ -488,11 +595,11 @@ def test_stage_runtime_release_can_pin_wheel_platform_tag(tmp_path: Path) -> Non tmp_path / "runtime-stage", "0.116.0a1", package_archive, - platform_tag="musllinux_1_1_x86_64", + platform_tag="manylinux_2_17_x86_64", ) pyproject = (staged / "pyproject.toml").read_text() - assert 'platform-tag = "musllinux_1_1_x86_64"' in pyproject + assert 'platform-tag = "manylinux_2_17_x86_64"' in pyproject def test_stage_runtime_release_rejects_incomplete_package_layout(tmp_path: Path) -> None: @@ -581,7 +688,7 @@ def test_sdk_beta_release_can_pin_stable_runtime(tmp_path: Path) -> None: ) runtime_stage = script.stage_python_runtime_package( tmp_path / "runtime-stage", - "0.132.0", + "0.137.0a4", package_archive, ) @@ -594,7 +701,7 @@ def test_sdk_beta_release_can_pin_stable_runtime(tmp_path: Path) -> None: "sdk_dependencies": sdk_pyproject["project"]["dependencies"], } == { "sdk_version": "0.1.0b1", - "runtime_version": "0.132.0", + "runtime_version": "0.137.0a4", "sdk_dependencies": [ "pydantic>=2.12", "openai-codex-cli-bin==0.137.0a4", @@ -656,7 +763,7 @@ def test_stage_runtime_stages_package_without_type_generation(tmp_path: Path) -> "--codex-version", "rust-v0.116.0-alpha.1", "--platform-tag", - "musllinux_1_1_x86_64", + "manylinux_2_17_x86_64", ] ) @@ -687,7 +794,7 @@ def fake_current_sdk_version() -> str: script.run_command(args, ops) - assert calls == ["stage_runtime:0.116.0a1:musllinux_1_1_x86_64:codex-package.tar.gz"] + assert calls == 
["stage_runtime:0.116.0a1:manylinux_2_17_x86_64:codex-package.tar.gz"] def test_default_runtime_is_resolved_from_installed_runtime_package( diff --git a/sdk/python/tests/test_public_api_signatures.py b/sdk/python/tests/test_public_api_signatures.py index 25a3b831d25..c26ac90f6b1 100644 --- a/sdk/python/tests/test_public_api_signatures.py +++ b/sdk/python/tests/test_public_api_signatures.py @@ -371,7 +371,6 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], Thread.turn: [ "approval_mode", - "client_user_message_id", "cwd", "effort", "model", @@ -383,7 +382,6 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], Thread.run: [ "approval_mode", - "client_user_message_id", "cwd", "effort", "model", @@ -448,7 +446,6 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], AsyncThread.turn: [ "approval_mode", - "client_user_message_id", "cwd", "effort", "model", @@ -460,7 +457,6 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None: ], AsyncThread.run: [ "approval_mode", - "client_user_message_id", "cwd", "effort", "model", From 5540ca946acde3be80c63a3204f6fefe6af170c5 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 18:31:21 -0700 Subject: [PATCH 10/15] codex: remove SDK changes from retry PR (#25147) --- sdk/python/pyproject.toml | 6 +- sdk/python/scripts/update_sdk_artifacts.py | 6 +- .../generated/notification_registry.py | 2 - .../src/openai_codex/generated/v2_all.py | 486 +++--------------- .../test_artifact_workflow_and_binaries.py | 160 +----- sdk/python/tests/test_contract_generation.py | 2 +- sdk/python/uv.lock | 20 +- 7 files changed, 118 insertions(+), 564 deletions(-) diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index b70df0cb534..9c93be69a5d 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ "Intended Audience :: Developers", "Topic :: Software Development :: Libraries :: 
Python Modules", ] -dependencies = ["pydantic>=2.12", "openai-codex-cli-bin==0.137.0a4"] +dependencies = ["pydantic>=2.12", "openai-codex-cli-bin==0.132.0"] [project.urls] Homepage = "https://github.com/openai/codex" @@ -78,10 +78,10 @@ combine-as-imports = true [tool.uv] exclude-newer = "7 days" -exclude-newer-package = { openai-codex-cli-bin = "2026-06-03T19:00:00Z" } +exclude-newer-package = { openai-codex-cli-bin = "2026-05-20T21:00:00Z" } index-strategy = "first-index" [tool.uv.pip] exclude-newer = "7 days" -exclude-newer-package = { openai-codex-cli-bin = "2026-06-03T19:00:00Z" } +exclude-newer-package = { openai-codex-cli-bin = "2026-05-20T21:00:00Z" } index-strategy = "first-index" diff --git a/sdk/python/scripts/update_sdk_artifacts.py b/sdk/python/scripts/update_sdk_artifacts.py index a4af5d605e1..1742ced547e 100755 --- a/sdk/python/scripts/update_sdk_artifacts.py +++ b/sdk/python/scripts/update_sdk_artifacts.py @@ -1163,9 +1163,7 @@ def generate_public_api_flat_methods() -> None: turn_start_fields = _load_public_fields( "openai_codex.generated.v2_all", "TurnStartParams", - # Keep the wire model current without exposing this app-server field - # through the ergonomic Python API yet. - exclude={"thread_id", "input", "client_user_message_id", *approval_fields}, + exclude={"thread_id", "input", *approval_fields}, ) turn_start_fields = _replace_public_sandbox_field(turn_start_fields, wire_name="sandbox_policy") @@ -1269,7 +1267,7 @@ def build_parser() -> argparse.ArgumentParser: "--platform-tag", help=( "Optional wheel platform tag override, for example " - "macosx_11_0_arm64 or manylinux_2_17_x86_64." + "macosx_11_0_arm64 or musllinux_1_1_x86_64." 
), ) return parser diff --git a/sdk/python/src/openai_codex/generated/notification_registry.py b/sdk/python/src/openai_codex/generated/notification_registry.py index d5e620a7a3c..c55eb5b9b75 100644 --- a/sdk/python/src/openai_codex/generated/notification_registry.py +++ b/sdk/python/src/openai_codex/generated/notification_registry.py @@ -57,7 +57,6 @@ from .v2_all import ThreadRealtimeStartedNotification from .v2_all import ThreadRealtimeTranscriptDeltaNotification from .v2_all import ThreadRealtimeTranscriptDoneNotification -from .v2_all import ThreadSettingsUpdatedNotification from .v2_all import ThreadStartedNotification from .v2_all import ThreadStatusChangedNotification from .v2_all import ThreadTokenUsageUpdatedNotification @@ -123,7 +122,6 @@ "thread/realtime/started": ThreadRealtimeStartedNotification, "thread/realtime/transcript/delta": ThreadRealtimeTranscriptDeltaNotification, "thread/realtime/transcript/done": ThreadRealtimeTranscriptDoneNotification, - "thread/settings/updated": ThreadSettingsUpdatedNotification, "thread/started": ThreadStartedNotification, "thread/status/changed": ThreadStatusChangedNotification, "thread/tokenUsage/updated": ThreadTokenUsageUpdatedNotification, diff --git a/sdk/python/src/openai_codex/generated/v2_all.py b/sdk/python/src/openai_codex/generated/v2_all.py index 15ede1801cf..85120b82932 100644 --- a/sdk/python/src/openai_codex/generated/v2_all.py +++ b/sdk/python/src/openai_codex/generated/v2_all.py @@ -56,7 +56,7 @@ class ActivePermissionProfile(BaseModel): extends: Annotated[ str | None, Field( - description="Parent profile identifier from the selected permissions profile's `extends` setting, when present." + description="Parent profile identifier once permissions profiles support inheritance. This is currently always `null`." 
), ] = None id: Annotated[ @@ -77,11 +77,6 @@ class AddCreditsNudgeEmailStatus(Enum): cooldown_active = "cooldown_active" -class AdditionalContextKind(Enum): - untrusted = "untrusted" - application = "application" - - class AdditionalNetworkPermissions(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -256,11 +251,6 @@ class AuthMode(Enum): agent_identity = "agentIdentity" -class AutoCompactTokenLimitScope(Enum): - total = "total" - body_after_prefix = "body_after_prefix" - - class AutoReviewDecisionSource(RootModel[Literal["agent"]]): model_config = ConfigDict( populate_by_name=True, @@ -573,13 +563,6 @@ class CommandMigration(BaseModel): name: str -class ComputerUseRequirements(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - allow_locked_computer_use: Annotated[bool | None, Field(alias="allowLockedComputerUse")] = None - - class MdmConfigLayerSource(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -602,22 +585,6 @@ class SystemConfigLayerSource(BaseModel): type: Annotated[Literal["system"], Field(title="SystemConfigLayerSourceType")] -class EnterpriseManagedConfigLayerSource(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - id: Annotated[str, Field(description="Stable identifier for the delivered layer.")] - name: Annotated[ - str, - Field( - description="Admin-facing name for the delivered layer. This is surfaced in diagnostics so users know which cloud layer needs administrator attention." 
- ), - ] - type: Annotated[ - Literal["enterpriseManaged"], Field(title="EnterpriseManagedConfigLayerSourceType") - ] - - class UserConfigLayerSource(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -677,7 +644,6 @@ class ConfigLayerSource( RootModel[ MdmConfigLayerSource | SystemConfigLayerSource - | EnterpriseManagedConfigLayerSource | UserConfigLayerSource | ProjectConfigLayerSource | SessionFlagsConfigLayerSource @@ -691,7 +657,6 @@ class ConfigLayerSource( root: ( MdmConfigLayerSource | SystemConfigLayerSource - | EnterpriseManagedConfigLayerSource | UserConfigLayerSource | ProjectConfigLayerSource | SessionFlagsConfigLayerSource @@ -710,7 +675,7 @@ class ConfigReadParams(BaseModel): description="Optional working directory to resolve project config layers. If specified, return the effective config as seen from that directory (i.e., including any project layers between `cwd` and the project/repo root)." ), ] = None - include_layers: Annotated[bool | None, Field(alias="includeLayers")] = None + include_layers: Annotated[bool | None, Field(alias="includeLayers")] = False class CommandConfiguredHookHandler(BaseModel): @@ -948,7 +913,7 @@ class FeedbackUploadParams(BaseModel): ) classification: str extra_log_files: Annotated[list[str] | None, Field(alias="extraLogFiles")] = None - include_logs: Annotated[bool | None, Field(alias="includeLogs")] = None + include_logs: Annotated[bool, Field(alias="includeLogs")] reason: str | None = None tags: dict[str, Any] | None = None thread_id: Annotated[str | None, Field(alias="threadId")] = None @@ -974,7 +939,7 @@ class FileChangeOutputDeltaNotification(BaseModel): class FileSystemAccessMode(Enum): read = "read" write = "write" - deny = "deny" + none = "none" class PathFileSystemPath(BaseModel): @@ -1311,17 +1276,6 @@ class InputTextFunctionCallOutputContentItem(BaseModel): ] -class EncryptedContentFunctionCallOutputContentItem(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - 
encrypted_content: str - type: Annotated[ - Literal["encrypted_content"], - Field(title="EncryptedContentFunctionCallOutputContentItemType"), - ] - - class FuzzyFileSearchMatchType(Enum): file = "file" directory = "directory" @@ -1381,7 +1335,7 @@ class GetAccountParams(BaseModel): alias="refreshToken", description="When `true`, requests a proactive token refresh before returning.\n\nIn managed auth mode this triggers the normal refresh-token flow. In external auth mode this flag is ignored. Clients should refresh tokens themselves and call `account/login/start` with `chatgptAuthTokens`.", ), - ] = None + ] = False class GitInfo(BaseModel): @@ -1471,8 +1425,6 @@ class HookEventName(Enum): post_compact = "postCompact" session_start = "sessionStart" user_prompt_submit = "userPromptSubmit" - subagent_start = "subagentStart" - subagent_stop = "subagentStop" stop = "stop" @@ -1531,7 +1483,6 @@ class HookSource(Enum): session_flags = "sessionFlags" plugin = "plugin" cloud_requirements = "cloudRequirements" - cloud_managed_config = "cloudManagedConfig" legacy_managed_config_file = "legacyManagedConfigFile" legacy_managed_config_mdm = "legacyManagedConfigMdm" unknown = "unknown" @@ -1555,8 +1506,6 @@ class HooksListParams(BaseModel): class ImageDetail(Enum): - auto = "auto" - low = "low" high = "high" original = "original" @@ -1793,8 +1742,6 @@ class ManagedHooksRequirements(BaseModel): pre_tool_use: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="PreToolUse")] session_start: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="SessionStart")] stop: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="Stop")] - subagent_start: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="SubagentStart")] - subagent_stop: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="SubagentStop")] user_prompt_submit: Annotated[list[ConfiguredHookMatcherGroup], Field(alias="UserPromptSubmit")] managed_dir: Annotated[str | None, Field(alias="managedDir")] = None 
windows_managed_dir: Annotated[str | None, Field(alias="windowsManagedDir")] = None @@ -1888,18 +1835,6 @@ class McpResourceReadParams(BaseModel): uri: str -class McpServerInfo(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - description: str | None = None - icons: list | None = None - name: str - title: str | None = None - version: str - website_url: Annotated[str | None, Field(alias="websiteUrl")] = None - - class McpServerMigration(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -2207,7 +2142,7 @@ class NetworkRequirements(BaseModel): class NetworkUnixSocketPermission(Enum): allow = "allow" - deny = "deny" + none = "none" class NonSteerableTurnKind(Enum): @@ -2253,32 +2188,6 @@ class PatchChangeKind( root: AddPatchChangeKind | DeletePatchChangeKind | UpdatePatchChangeKind -class PermissionProfileListParams(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - cursor: Annotated[ - str | None, Field(description="Opaque pagination cursor returned by a previous call.") - ] = None - cwd: Annotated[ - str | None, - Field(description="Optional working directory to resolve project config layers."), - ] = None - limit: Annotated[ - int | None, Field(description="Optional page size; defaults to the full result set.", ge=0) - ] = None - - -class PermissionProfileSummary(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - description: Annotated[ - str | None, Field(description="Optional user-facing description for display in clients.") - ] = None - id: Annotated[str, Field(description="Available permission profile identifier.")] - - class Personality(Enum): none = "none" friendly = "friendly" @@ -2423,7 +2332,6 @@ class PluginInterface(BaseModel): class PluginListMarketplaceKind(Enum): local = "local" - vertical = "vertical" workspace_directory = "workspace-directory" shared_with_me = "shared-with-me" @@ -3561,20 +3469,6 @@ class SkillsConfigWriteResponse(BaseModel): effective_enabled: 
Annotated[bool, Field(alias="effectiveEnabled")] -class SkillsExtraRootsSetParams(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - extra_roots: Annotated[list[AbsolutePathBuf], Field(alias="extraRoots")] - - -class SkillsExtraRootsSetResponse(BaseModel): - pass - model_config = ConfigDict( - populate_by_name=True, - ) - - class SkillsListParams(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -3597,16 +3491,6 @@ class SortDirection(Enum): desc = "desc" -class SpendControlLimitSnapshot(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - limit: str - remaining_percent: Annotated[int, Field(alias="remainingPercent")] - resets_at: Annotated[int, Field(alias="resetsAt")] - used: str - - class SubAgentSourceValue(Enum): review = "review" compact = "compact" @@ -3740,20 +3624,6 @@ class ThreadCompactStartResponse(BaseModel): ) -class ThreadGoalClearParams(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - thread_id: Annotated[str, Field(alias="threadId")] - - -class ThreadGoalClearResponse(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - cleared: bool - - class ThreadGoalClearedNotification(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -3761,13 +3631,6 @@ class ThreadGoalClearedNotification(BaseModel): thread_id: Annotated[str, Field(alias="threadId")] -class ThreadGoalGetParams(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - thread_id: Annotated[str, Field(alias="threadId")] - - class ThreadGoalStatus(Enum): active = "active" paused = "paused" @@ -3886,7 +3749,6 @@ class McpToolCallThreadItem(BaseModel): error: McpToolCallError | None = None id: str mcp_app_resource_uri: Annotated[str | None, Field(alias="mcpAppResourceUri")] = None - plugin_id: Annotated[str | None, Field(alias="pluginId")] = None result: McpToolCallResult | None = None server: str status: McpToolCallStatus @@ -4060,7 +3922,7 @@ class 
ThreadReadParams(BaseModel): alias="includeTurns", description="When true, include turns and their items from rollout history.", ), - ] = None + ] = False thread_id: Annotated[str, Field(alias="threadId")] @@ -4674,19 +4536,10 @@ class AccountUpdatedNotification(BaseModel): plan_type: Annotated[PlanType | None, Field(alias="planType")] = None -class AdditionalContextEntry(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - kind: AdditionalContextKind - value: str - - class AppConfig(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - approvals_reviewer: ApprovalsReviewer | None = None default_tools_approval_mode: AppToolApproval | None = None default_tools_enabled: bool | None = None destructive_enabled: bool | None = None @@ -4776,24 +4629,6 @@ class ThreadNameSetRequest(BaseModel): params: ThreadSetNameParams -class ThreadGoalGetRequest(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - id: RequestId - method: Annotated[Literal["thread/goal/get"], Field(title="Thread/goal/getRequestMethod")] - params: ThreadGoalGetParams - - -class ThreadGoalClearRequest(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - id: RequestId - method: Annotated[Literal["thread/goal/clear"], Field(title="Thread/goal/clearRequestMethod")] - params: ThreadGoalClearParams - - class ThreadMetadataUpdateRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -4895,17 +4730,6 @@ class SkillsListRequest(BaseModel): params: SkillsListParams -class SkillsExtraRootsSetRequest(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - id: RequestId - method: Annotated[ - Literal["skills/extraRoots/set"], Field(title="Skills/extraRoots/setRequestMethod") - ] - params: SkillsExtraRootsSetParams - - class HooksListRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -5171,17 +4995,6 @@ class ExperimentalFeatureListRequest(BaseModel): params: ExperimentalFeatureListParams -class 
PermissionProfileListRequest(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - id: RequestId - method: Annotated[ - Literal["permissionProfile/list"], Field(title="PermissionProfile/listRequestMethod") - ] - params: PermissionProfileListParams - - class ExperimentalFeatureEnablementSetRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -5618,22 +5431,16 @@ class ConfigRequirements(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - allow_appshots: Annotated[bool | None, Field(alias="allowAppshots")] = None allow_managed_hooks_only: Annotated[bool | None, Field(alias="allowManagedHooksOnly")] = None allowed_approval_policies: Annotated[ list[AskForApproval] | None, Field(alias="allowedApprovalPolicies") ] = None - allowed_permissions: Annotated[list[str] | None, Field(alias="allowedPermissions")] = None allowed_sandbox_modes: Annotated[ list[SandboxMode] | None, Field(alias="allowedSandboxModes") ] = None allowed_web_search_modes: Annotated[ list[WebSearchMode] | None, Field(alias="allowedWebSearchModes") ] = None - allowed_windows_sandbox_implementations: Annotated[ - list[WindowsSandboxSetupMode] | None, Field(alias="allowedWindowsSandboxImplementations") - ] = None - computer_use: Annotated[ComputerUseRequirements | None, Field(alias="computerUse")] = None enforce_residency: Annotated[ResidencyRequirement | None, Field(alias="enforceResidency")] = ( None ) @@ -5801,19 +5608,13 @@ class InputImageFunctionCallOutputContentItem(BaseModel): class FunctionCallOutputContentItem( - RootModel[ - InputTextFunctionCallOutputContentItem - | InputImageFunctionCallOutputContentItem - | EncryptedContentFunctionCallOutputContentItem - ] + RootModel[InputTextFunctionCallOutputContentItem | InputImageFunctionCallOutputContentItem] ): model_config = ConfigDict( populate_by_name=True, ) root: Annotated[ - InputTextFunctionCallOutputContentItem - | InputImageFunctionCallOutputContentItem - | 
EncryptedContentFunctionCallOutputContentItem, + InputTextFunctionCallOutputContentItem | InputImageFunctionCallOutputContentItem, Field( description="Responses API compatible content items that can be returned by a tool call. This is a subset of ContentItem with the types we support as function call outputs." ), @@ -5966,7 +5767,6 @@ class ListMcpServerStatusParams(BaseModel): int | None, Field(description="Optional page size; defaults to a server-defined value.", ge=0), ] = None - thread_id: Annotated[str | None, Field(alias="threadId")] = None class McpResourceReadResponse(BaseModel): @@ -5984,7 +5784,6 @@ class McpServerStatus(BaseModel): name: str resource_templates: Annotated[list[ResourceTemplate], Field(alias="resourceTemplates")] resources: list[Resource] - server_info: Annotated[McpServerInfo | None, Field(alias="serverInfo")] = None tools: dict[str, Tool] @@ -6018,13 +5817,6 @@ class Model(BaseModel): ] = [] availability_nux: Annotated[ModelAvailabilityNux | None, Field(alias="availabilityNux")] = None default_reasoning_effort: Annotated[ReasoningEffort, Field(alias="defaultReasoningEffort")] - default_service_tier: Annotated[ - str | None, - Field( - alias="defaultServiceTier", - description="Catalog default service tier id for this model, when one is configured.", - ), - ] = None description: str display_name: Annotated[str, Field(alias="displayName")] hidden: bool @@ -6067,20 +5859,6 @@ class OverriddenMetadata(BaseModel): overriding_layer: Annotated[ConfigLayerMetadata, Field(alias="overridingLayer")] -class PermissionProfileListResponse(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - data: list[PermissionProfileSummary] - next_cursor: Annotated[ - str | None, - Field( - alias="nextCursor", - description="Opaque cursor to pass to the next call to continue after the last item. 
If None, there are no more items to return.", - ), - ] = None - - class PluginSharePrincipal(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6148,9 +5926,6 @@ class RateLimitSnapshot(BaseModel): populate_by_name=True, ) credits: CreditsSnapshot | None = None - individual_limit: Annotated[ - SpendControlLimitSnapshot | None, Field(alias="individualLimit") - ] = None limit_id: Annotated[str | None, Field(alias="limitId")] = None limit_name: Annotated[str | None, Field(alias="limitName")] = None plan_type: Annotated[PlanType | None, Field(alias="planType")] = None @@ -6557,30 +6332,6 @@ class ThreadGoal(BaseModel): updated_at: Annotated[int, Field(alias="updatedAt")] -class ThreadGoalGetResponse(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - goal: ThreadGoal | None = None - - -class ThreadGoalSetParams(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - objective: str | None = None - status: ThreadGoalStatus | None = None - thread_id: Annotated[str, Field(alias="threadId")] - token_budget: Annotated[int | None, Field(alias="tokenBudget")] = None - - -class ThreadGoalSetResponse(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - goal: ThreadGoal - - class ThreadGoalUpdatedNotification(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6594,7 +6345,6 @@ class UserMessageThreadItem(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - client_id: Annotated[str | None, Field(alias="clientId")] = None content: list[UserInput] id: str type: Annotated[Literal["userMessage"], Field(title="UserMessageThreadItemType")] @@ -6786,55 +6536,6 @@ class ThreadListParams(BaseModel): ] = None -class ThreadResumeInitialTurnsPageParams(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - items_view: Annotated[ - TurnItemsView | None, - Field( - alias="itemsView", - description="How much item detail to include for each returned turn; defaults to summary.", - ), 
- ] = None - limit: Annotated[int | None, Field(description="Optional turn page size.", ge=0)] = None - sort_direction: Annotated[ - SortDirection | None, - Field( - alias="sortDirection", - description="Optional turn pagination direction; defaults to descending.", - ), - ] = None - - -class ThreadSettings(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - active_permission_profile: Annotated[ - ActivePermissionProfile | None, Field(alias="activePermissionProfile") - ] = None - approval_policy: Annotated[AskForApproval, Field(alias="approvalPolicy")] - approvals_reviewer: Annotated[ApprovalsReviewer, Field(alias="approvalsReviewer")] - collaboration_mode: Annotated[CollaborationMode, Field(alias="collaborationMode")] - cwd: AbsolutePathBuf - effort: ReasoningEffort | None = None - model: str - model_provider: Annotated[str, Field(alias="modelProvider")] - personality: Personality | None = None - sandbox_policy: Annotated[SandboxPolicy, Field(alias="sandboxPolicy")] - service_tier: Annotated[str | None, Field(alias="serviceTier")] = None - summary: ReasoningSummary | None = None - - -class ThreadSettingsUpdatedNotification(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - thread_id: Annotated[str, Field(alias="threadId")] - thread_settings: Annotated[ThreadSettings, Field(alias="threadSettings")] - - class ThreadStartParams(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -6947,7 +6648,6 @@ class TurnStartParams(BaseModel): description="Override where approval requests are routed for review on this turn and subsequent turns.", ), ] = None - client_user_message_id: Annotated[str | None, Field(alias="clientUserMessageId")] = None cwd: Annotated[ str | None, Field(description="Override the working directory for this turn and subsequent turns."), @@ -6996,7 +6696,6 @@ class TurnSteerParams(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - client_user_message_id: Annotated[str | None, 
Field(alias="clientUserMessageId")] = None expected_turn_id: Annotated[ str, Field( @@ -7104,15 +6803,6 @@ class ThreadForkRequest(BaseModel): params: ThreadForkParams -class ThreadGoalSetRequest(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - id: RequestId - method: Annotated[Literal["thread/goal/set"], Field(title="Thread/goal/setRequestMethod")] - params: ThreadGoalSetParams - - class ThreadListRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7201,41 +6891,6 @@ class ConfigValueWriteRequest(BaseModel): params: ConfigValueWriteParams -class Config(BaseModel): - model_config = ConfigDict( - extra="allow", - populate_by_name=True, - ) - analytics: AnalyticsConfig | None = None - approval_policy: AskForApproval | None = None - approvals_reviewer: Annotated[ - ApprovalsReviewer | None, - Field( - description="[UNSTABLE] Optional default for where approval requests are routed for review." - ), - ] = None - compact_prompt: str | None = None - desktop: dict[str, Any] | None = None - developer_instructions: str | None = None - forced_chatgpt_workspace_id: ForcedChatgptWorkspaceIds | None = None - forced_login_method: ForcedLoginMethod | None = None - instructions: str | None = None - model: str | None = None - model_auto_compact_token_limit: int | None = None - model_auto_compact_token_limit_scope: AutoCompactTokenLimitScope | None = None - model_context_window: int | None = None - model_provider: str | None = None - model_reasoning_effort: ReasoningEffort | None = None - model_reasoning_summary: ReasoningSummary | None = None - model_verbosity: Verbosity | None = None - review_model: str | None = None - sandbox_mode: SandboxMode | None = None - sandbox_workspace_write: SandboxWorkspaceWrite | None = None - service_tier: str | None = None - tools: ToolsV2 | None = None - web_search: WebSearchMode | None = None - - class ConfigBatchWriteParams(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7258,15 +6913,6 
@@ class ConfigBatchWriteParams(BaseModel): ] = None -class ConfigReadResponse(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - config: Config - layers: list[ConfigLayer] | None = None - origins: dict[str, ConfigLayerMetadata] - - class ConfigWriteResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7469,6 +7115,29 @@ class PluginSummary(BaseModel): source: PluginSource +class ProfileV2(BaseModel): + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + approval_policy: AskForApproval | None = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + description="[UNSTABLE] Optional profile-level override for where approval requests are routed for review. If omitted, the enclosing config default is used." + ), + ] = None + chatgpt_base_url: str | None = None + model: str | None = None + model_provider: str | None = None + model_reasoning_effort: ReasoningEffort | None = None + model_reasoning_summary: ReasoningSummary | None = None + model_verbosity: Verbosity | None = None + service_tier: str | None = None + tools: ToolsV2 | None = None + web_search: WebSearchMode | None = None + + class RequestPermissionProfile(BaseModel): model_config = ConfigDict( extra="forbid", @@ -7560,16 +7229,6 @@ class ThreadGoalUpdatedServerNotification(BaseModel): params: ThreadGoalUpdatedNotification -class ThreadSettingsUpdatedServerNotification(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - method: Annotated[ - Literal["thread/settings/updated"], Field(title="Thread/settings/updatedNotificationMethod") - ] - params: ThreadSettingsUpdatedNotification - - class ThreadTokenUsageUpdatedServerNotification(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7734,15 +7393,6 @@ class TurnStartedNotification(BaseModel): turn: Turn -class TurnsPage(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - backwards_cursor: Annotated[str | None, 
Field(alias="backwardsCursor")] = None - data: list[Turn] - next_cursor: Annotated[str | None, Field(alias="nextCursor")] = None - - class PluginShareSaveRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -7761,6 +7411,51 @@ class ConfigBatchWriteRequest(BaseModel): params: ConfigBatchWriteParams +class Config(BaseModel): + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + analytics: AnalyticsConfig | None = None + approval_policy: AskForApproval | None = None + approvals_reviewer: Annotated[ + ApprovalsReviewer | None, + Field( + description="[UNSTABLE] Optional default for where approval requests are routed for review." + ), + ] = None + compact_prompt: str | None = None + desktop: dict[str, Any] | None = None + developer_instructions: str | None = None + forced_chatgpt_workspace_id: ForcedChatgptWorkspaceIds | None = None + forced_login_method: ForcedLoginMethod | None = None + instructions: str | None = None + model: str | None = None + model_auto_compact_token_limit: int | None = None + model_context_window: int | None = None + model_provider: str | None = None + model_reasoning_effort: ReasoningEffort | None = None + model_reasoning_summary: ReasoningSummary | None = None + model_verbosity: Verbosity | None = None + profile: str | None = None + profiles: dict[str, ProfileV2] | None = {} + review_model: str | None = None + sandbox_mode: SandboxMode | None = None + sandbox_workspace_write: SandboxWorkspaceWrite | None = None + service_tier: str | None = None + tools: ToolsV2 | None = None + web_search: WebSearchMode | None = None + + +class ConfigReadResponse(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + config: Config + layers: list[ConfigLayer] | None = None + origins: dict[str, ConfigLayerMetadata] + + class ExternalAgentConfigDetectResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -8044,13 +7739,6 @@ class Thread(BaseModel): ), ] name: Annotated[str | None, 
Field(description="Optional user-facing thread title.")] = None - parent_thread_id: Annotated[ - str | None, - Field( - alias="parentThreadId", - description="The ID of the parent thread. This will only be set if this thread is a subagent.", - ), - ] = None path: Annotated[str | None, Field(description="[UNSTABLE] Path to the thread on disk.")] = None preview: Annotated[ str, Field(description="Usually the first user message in the thread, if available.") @@ -8204,14 +7892,6 @@ class ThreadRollbackResponse(BaseModel): ] -class ThreadSearchResult(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - ) - snippet: str - thread: Thread - - class ThreadStartResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -8280,9 +7960,6 @@ class ClientRequest( | ThreadArchiveRequest | ThreadUnsubscribeRequest | ThreadNameSetRequest - | ThreadGoalSetRequest - | ThreadGoalGetRequest - | ThreadGoalClearRequest | ThreadMetadataUpdateRequest | ThreadUnarchiveRequest | ThreadCompactStartRequest @@ -8294,7 +7971,6 @@ class ClientRequest( | ThreadReadRequest | ThreadInjectItemsRequest | SkillsListRequest - | SkillsExtraRootsSetRequest | HooksListRequest | MarketplaceAddRequest | MarketplaceRemoveRequest @@ -8328,7 +8004,6 @@ class ClientRequest( | ModelListRequest | ModelProviderCapabilitiesReadRequest | ExperimentalFeatureListRequest - | PermissionProfileListRequest | ExperimentalFeatureEnablementSetRequest | McpServerOauthLoginRequest | ConfigMcpServerReloadRequest @@ -8368,9 +8043,6 @@ class ClientRequest( | ThreadArchiveRequest | ThreadUnsubscribeRequest | ThreadNameSetRequest - | ThreadGoalSetRequest - | ThreadGoalGetRequest - | ThreadGoalClearRequest | ThreadMetadataUpdateRequest | ThreadUnarchiveRequest | ThreadCompactStartRequest @@ -8382,7 +8054,6 @@ class ClientRequest( | ThreadReadRequest | ThreadInjectItemsRequest | SkillsListRequest - | SkillsExtraRootsSetRequest | HooksListRequest | MarketplaceAddRequest | MarketplaceRemoveRequest @@ -8416,7 
+8087,6 @@ class ClientRequest( | ModelListRequest | ModelProviderCapabilitiesReadRequest | ExperimentalFeatureListRequest - | PermissionProfileListRequest | ExperimentalFeatureEnablementSetRequest | McpServerOauthLoginRequest | ConfigMcpServerReloadRequest @@ -8488,7 +8158,6 @@ class ServerNotification( | ThreadNameUpdatedServerNotification | ThreadGoalUpdatedServerNotification | ThreadGoalClearedServerNotification - | ThreadSettingsUpdatedServerNotification | ThreadTokenUsageUpdatedServerNotification | TurnStartedServerNotification | HookStartedServerNotification @@ -8558,7 +8227,6 @@ class ServerNotification( | ThreadNameUpdatedServerNotification | ThreadGoalUpdatedServerNotification | ThreadGoalClearedServerNotification - | ThreadSettingsUpdatedServerNotification | ThreadTokenUsageUpdatedServerNotification | TurnStartedServerNotification | HookStartedServerNotification diff --git a/sdk/python/tests/test_artifact_workflow_and_binaries.py b/sdk/python/tests/test_artifact_workflow_and_binaries.py index 501f1e100a7..d3f2f8d0ec1 100644 --- a/sdk/python/tests/test_artifact_workflow_and_binaries.py +++ b/sdk/python/tests/test_artifact_workflow_and_binaries.py @@ -14,18 +14,6 @@ ROOT = Path(__file__).resolve().parents[1] -def _load_root_format_script_module(): - """Load the root formatter driver so tests exercise its real command graph.""" - script_path = ROOT.parents[1] / "scripts" / "format.py" - spec = importlib.util.spec_from_file_location("format_repo", script_path) - if spec is None or spec.loader is None: - raise AssertionError(f"Failed to load script module: {script_path}") - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - return module - - def _load_update_script_module(): """Load the maintenance script as a module so tests exercise real helpers.""" script_path = ROOT / "scripts" / "update_sdk_artifacts.py" @@ -80,135 +68,40 @@ def test_generation_has_single_maintenance_entrypoint_script() -> 
None: assert scripts == ["update_sdk_artifacts.py"] -def test_root_fmt_recipes_use_shared_formatter_driver() -> None: - """The root formatting recipes should use the shared cross-platform driver.""" +def test_root_fmt_recipe_formats_rust_and_python_sdk() -> None: + """The repo fmt command should work from Rust and Python SDK directories.""" justfile = ROOT.parents[1] / "justfile" lines = justfile.read_text().splitlines() fmt_index = lines.index("fmt:") - fmt_check_index = lines.index("fmt-check:") next_recipe_index = next( index - for index in range(fmt_check_index + 1, len(lines)) + for index in range(fmt_index + 1, len(lines)) if lines[index] and not lines[index].startswith((" ", "\t", "#")) ) + fmt_recipe = lines[fmt_index:next_recipe_index] actual = { "working_directory": lines[0], - "fmt_comment": next(line for line in reversed(lines[:fmt_index]) if line.startswith("#")), - "fmt_commands": [ - line.strip() - for line in lines[fmt_index + 1 : fmt_check_index] - if line.strip() and not line.startswith("#") - ], - "fmt_check_comment": next( - line for line in reversed(lines[:fmt_check_index]) if line.startswith("#") - ), - "fmt_check_commands": [ - line.strip() for line in lines[fmt_check_index + 1 : next_recipe_index] if line.strip() - ], + "previous_attribute": lines[fmt_index - 1], + "commands": [line.strip() for line in fmt_recipe[1:] if line.strip()], } expected = { "working_directory": 'set working-directory := "codex-rs"', - "fmt_comment": "# Format the justfile, Rust, Python SDK code, and Python scripts.", - "fmt_commands": ["{{ python }} ../scripts/format.py"], - "fmt_check_comment": "# Check formatting without modifying files.", - "fmt_check_commands": ["{{ python }} ../scripts/format.py --check"], + "previous_attribute": "# Format Rust and Python SDK code.", + "commands": [ + "cargo fmt -- --config imports_granularity=Item 2>/dev/null", + "uv run --frozen --project ../sdk/python --extra dev ruff check --fix --fix-only ../sdk/python", + "uv run --frozen 
--project ../sdk/python --extra dev ruff format ../sdk/python", + ], } assert actual == expected, ( - "The root formatting recipes must use the shared formatter driver. " - "Fix the recipes in `justfile`, then run `just fmt`.\n" + "The root `just fmt` recipe must run Rust fmt and Python SDK Ruff. " + "Fix the `fmt` recipe in `justfile`, then run `just fmt`.\n" f"Expected: {json.dumps(expected, indent=2)}\n" f"Actual: {json.dumps(actual, indent=2)}" ) -def test_root_format_driver_covers_all_formatter_groups() -> None: - """The shared driver should retain every formatter in both modes.""" - script = _load_root_format_script_module() - formatters = script.formatter_groups(check=False) - checks = script.formatter_groups(check=True) - - assert [group.name for group in formatters] == [ - "Just", - "Rust", - "Python SDK", - "Python scripts", - ] - assert [group.name for group in checks] == [group.name for group in formatters] - assert [len(group.commands) for group in formatters] == [1, 1, 2, 1] - assert [len(group.commands) for group in checks] == [ - len(group.commands) for group in formatters - ] - sdk_uv_run_args = ( - "uv", - "run", - "--frozen", - "--project", - "sdk/python", - "--no-sync", - "--with", - "ruff", - ) - scripts_uv_run_args = ( - "uv", - "run", - "--frozen", - "--project", - "scripts", - "--no-sync", - "--with", - "ruff", - ) - assert all( - command.args[: len(sdk_uv_run_args)] == sdk_uv_run_args - for group in (formatters[2], checks[2]) - for command in group.commands - ) - assert all( - command.args[: len(scripts_uv_run_args)] == scripts_uv_run_args - for group in (formatters[3], checks[3]) - for command in group.commands - ) - assert formatters[2].commands[0].args[-5:] == ( - "ruff", - "check", - "--fix", - "--fix-only", - "sdk/python", - ) - assert checks[2].commands[0].args[-4:] == ( - "ruff", - "check", - "--diff", - "sdk/python", - ) - assert formatters[0].commands[-1].args == ("just", "--unstable", "--fmt") - assert checks[0].commands[-1].args 
== ("just", "--unstable", "--fmt", "--check") - assert formatters[1].commands[-1].args == ( - "cargo", - "fmt", - "--", - "--config", - "imports_granularity=Item", - ) - assert checks[1].commands[-1].args == ( - "cargo", - "fmt", - "--", - "--config", - "imports_granularity=Item", - "--check", - ) - assert [group.commands[-1].args[-3:] for group in formatters[2:]] == [ - ("ruff", "format", "sdk/python"), - ("ruff", "format", "scripts"), - ] - assert [group.commands[-1].args[-4:] for group in checks[2:]] == [ - ("ruff", "format", "--check", "sdk/python"), - ("ruff", "format", "--check", "scripts"), - ] - - def test_generate_types_wires_all_generation_steps() -> None: """The type generation command should refresh every schema-derived artifact.""" source = (ROOT / "scripts" / "update_sdk_artifacts.py").read_text() @@ -261,13 +154,12 @@ def test_schema_normalization_only_flattens_string_literal_oneofs( assert flattened == [ "MessagePhase", "TurnItemsView", - "AuthMode", "PluginAvailability", + "AuthMode", "InputModality", "ExperimentalFeatureStage", "ProcessOutputStream", "CommandExecOutputStream", - "AutoCompactTokenLimitScope", ] @@ -358,10 +250,10 @@ def test_source_sdk_template_pins_published_runtime() -> None: "dependencies": pyproject["project"]["dependencies"], } == { "sdk_template_version": "0.0.0-dev", - "runtime_pin": "0.137.0a4", + "runtime_pin": "0.132.0", "dependencies": [ "pydantic>=2.12", - "openai-codex-cli-bin==0.137.0a4", + "openai-codex-cli-bin==0.132.0", ], } @@ -436,7 +328,7 @@ def test_runtime_setup_reads_independent_runtime_pin_and_release_tags() -> None: } == { "package_name": "openai-codex-cli-bin", "sdk_template_version": "0.0.0-dev", - "runtime_pin": "0.137.0a4", + "runtime_pin": "0.132.0", "normalized_release_version": "0.116.0a1", "release_tag": "rust-v0.116.0-alpha.1", } @@ -595,11 +487,11 @@ def test_stage_runtime_release_can_pin_wheel_platform_tag(tmp_path: Path) -> Non tmp_path / "runtime-stage", "0.116.0a1", package_archive, - 
platform_tag="manylinux_2_17_x86_64", + platform_tag="musllinux_1_1_x86_64", ) pyproject = (staged / "pyproject.toml").read_text() - assert 'platform-tag = "manylinux_2_17_x86_64"' in pyproject + assert 'platform-tag = "musllinux_1_1_x86_64"' in pyproject def test_stage_runtime_release_rejects_incomplete_package_layout(tmp_path: Path) -> None: @@ -651,7 +543,7 @@ def test_stage_sdk_release_preserves_reviewed_runtime_pin(tmp_path: Path) -> Non "version": "0.1.0b1", "dependencies": [ "pydantic>=2.12", - "openai-codex-cli-bin==0.137.0a4", + "openai-codex-cli-bin==0.132.0", ], } assert ( @@ -688,7 +580,7 @@ def test_sdk_beta_release_can_pin_stable_runtime(tmp_path: Path) -> None: ) runtime_stage = script.stage_python_runtime_package( tmp_path / "runtime-stage", - "0.137.0a4", + "0.132.0", package_archive, ) @@ -701,10 +593,10 @@ def test_sdk_beta_release_can_pin_stable_runtime(tmp_path: Path) -> None: "sdk_dependencies": sdk_pyproject["project"]["dependencies"], } == { "sdk_version": "0.1.0b1", - "runtime_version": "0.137.0a4", + "runtime_version": "0.132.0", "sdk_dependencies": [ "pydantic>=2.12", - "openai-codex-cli-bin==0.137.0a4", + "openai-codex-cli-bin==0.132.0", ], } @@ -763,7 +655,7 @@ def test_stage_runtime_stages_package_without_type_generation(tmp_path: Path) -> "--codex-version", "rust-v0.116.0-alpha.1", "--platform-tag", - "manylinux_2_17_x86_64", + "musllinux_1_1_x86_64", ] ) @@ -794,7 +686,7 @@ def fake_current_sdk_version() -> str: script.run_command(args, ops) - assert calls == ["stage_runtime:0.116.0a1:manylinux_2_17_x86_64:codex-package.tar.gz"] + assert calls == ["stage_runtime:0.116.0a1:musllinux_1_1_x86_64:codex-package.tar.gz"] def test_default_runtime_is_resolved_from_installed_runtime_package( diff --git a/sdk/python/tests/test_contract_generation.py b/sdk/python/tests/test_contract_generation.py index 36d37735c7a..2c95d06a3e7 100644 --- a/sdk/python/tests/test_contract_generation.py +++ b/sdk/python/tests/test_contract_generation.py @@ -40,7 
+40,7 @@ def test_generated_files_are_up_to_date(): # Regenerate contract artifacts via the pinned runtime package, not a local # app-server binary from the checkout or CI environment. - assert importlib.metadata.version("openai-codex-cli-bin") == "0.137.0a4" + assert importlib.metadata.version("openai-codex-cli-bin") == "0.132.0" env = os.environ.copy() env.pop("CODEX_EXEC_PATH", None) python_bin = str(Path(sys.executable).parent) diff --git a/sdk/python/uv.lock b/sdk/python/uv.lock index 588c9f4fe2b..6d9e40e37a0 100644 --- a/sdk/python/uv.lock +++ b/sdk/python/uv.lock @@ -7,7 +7,7 @@ exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for exclude-newer-span = "P7D" [options.exclude-newer-package] -openai-codex-cli-bin = "2026-06-03T19:00:00Z" +openai-codex-cli-bin = "2026-05-20T21:00:00Z" [[package]] name = "annotated-types" @@ -299,7 +299,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "datamodel-code-generator", marker = "extra == 'dev'", specifier = "==0.31.2" }, - { name = "openai-codex-cli-bin", specifier = "==0.137.0a4" }, + { name = "openai-codex-cli-bin", specifier = "==0.132.0" }, { name = "pydantic", specifier = ">=2.12" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.8" }, @@ -308,17 +308,15 @@ provides-extras = ["dev"] [[package]] name = "openai-codex-cli-bin" -version = "0.137.0a4" +version = "0.132.0" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bd/60/af73ef1676cd477fa83ed4b889bf3b57c63c47dd87025b2cc4262793cff6/openai_codex_cli_bin-0.137.0a4-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b33c3917e0b58d527ee11a11a78ad390f7d8e6aa25577dd21665ab3c8bf5cf9a", size = 94300191, upload-time = "2026-06-03T18:44:36.312Z" }, - { url = 
"https://files.pythonhosted.org/packages/92/8f/d1a5f8c87176e00ef6a85798794f4530f5eb04e5a1a13468b5b3c3a361f9/openai_codex_cli_bin-0.137.0a4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3d0f0bc5becc88c61952fbfa9bd792ac9d74fa78b3a6bd40f545b612048b07eb", size = 83924479, upload-time = "2026-06-03T18:44:40.854Z" }, - { url = "https://files.pythonhosted.org/packages/3e/3c/fc00bcdc0c302208317d5eb1d0bfaab3024f351cd0121400f19baa6b19aa/openai_codex_cli_bin-0.137.0a4-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:2f1656339e2736868c4cce59f6d9e5c633879123687169b03b1137d42bf2c11a", size = 83363315, upload-time = "2026-06-03T18:44:44.851Z" }, - { url = "https://files.pythonhosted.org/packages/ec/09/39362e944ebeb12fcbfb86881fbb4dd6e806f77f7541c1f1f993bb9351a0/openai_codex_cli_bin-0.137.0a4-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:6454f838d44c56c1ed07a29b391fa412785e5dd2ffd06db0b62e62478c19bb64", size = 90611239, upload-time = "2026-06-03T18:44:49.338Z" }, - { url = "https://files.pythonhosted.org/packages/fa/38/87b1247fdfe95cddce7f7fe8331d6843cf037e14292c0f5004e23247133b/openai_codex_cli_bin-0.137.0a4-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:f5ae7401d00c65d56a75d9645d7bf87d809566a12d238e4b2a8b328a02f2316e", size = 83363315, upload-time = "2026-06-03T18:44:53.428Z" }, - { url = "https://files.pythonhosted.org/packages/fb/c4/3c693ad07e587f6b3a28128c417f2e831d81a40cdbd85c0e5f0f36aaff82/openai_codex_cli_bin-0.137.0a4-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:3dcec1e649448be498d6e7ec0e1f71dca83efa76063d90890dafb41e987069b7", size = 90611238, upload-time = "2026-06-03T18:44:57.612Z" }, - { url = "https://files.pythonhosted.org/packages/9e/26/81e037066b9b8d312a6f9e09015e452ce17630d5ab88e02a4c1d9503e4e8/openai_codex_cli_bin-0.137.0a4-py3-none-win_amd64.whl", hash = "sha256:9e13bf68e18e36bd3a0efd51213281c83e9f6ec22bdb7a45bd2e0211822733a9", size = 94744969, upload-time = "2026-06-03T18:45:02.23Z" }, - { url = 
"https://files.pythonhosted.org/packages/0d/a3/952bc2a5d62373a51fea161effe3b338b3417c2f6e65fe467ed91b205e2b/openai_codex_cli_bin-0.137.0a4-py3-none-win_arm64.whl", hash = "sha256:5ec4303ca2dcb5f838e0de3ca7f44050b6bcdd41d281a178c3a1420a985a515d", size = 86963504, upload-time = "2026-06-03T18:45:07.131Z" }, + { url = "https://files.pythonhosted.org/packages/be/a1/b92b7a1b73a83785d2e1dcd0faecd1b7f886a38cf02a30abe1c35f42f0f7/openai_codex_cli_bin-0.132.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:1c22b51dbd679413f84f00b9d8fd4e5cf8a1c0d1c7cc8c42bcb3f9f1b33e2334", size = 89403211, upload-time = "2026-05-20T02:37:22.311Z" }, + { url = "https://files.pythonhosted.org/packages/5f/68/163272e582de55a7f460e2329281267908d75d0fbcbbbb2c6749a6329e6b/openai_codex_cli_bin-0.132.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:56217495e6635c8a5d96df820cc0da5f46cd9b6ec6f3a5f67f1607d69ef74256", size = 79058685, upload-time = "2026-05-20T02:37:27.165Z" }, + { url = "https://files.pythonhosted.org/packages/0b/18/a60c6b137e7cd3959cae16ba757f57ca5702979b0ea107a21f516ba15d98/openai_codex_cli_bin-0.132.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:09642e7578a3078bfccc82af4077b085d42b0022b529e4b5c645e0a0af3397a4", size = 78689038, upload-time = "2026-05-20T02:37:31.548Z" }, + { url = "https://files.pythonhosted.org/packages/f8/eb/1b184307a67c1006d59b61636bcfcea73a89aa95271f6516ed28dce554ca/openai_codex_cli_bin-0.132.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:85aec095f9d144d7a2d1aff39fce77b7240f42014580c35801ba74b9317aa5f7", size = 85528820, upload-time = "2026-05-20T02:37:36.559Z" }, + { url = "https://files.pythonhosted.org/packages/0e/e8/1b823a8bf7b96d1513905ad79b16a146d797f81a19a6bc350a2f95a16661/openai_codex_cli_bin-0.132.0-py3-none-win_amd64.whl", hash = "sha256:3cb5c90c55baa39bd5ddc890d2068d3e1322a57a54d1d0e623819009a205c7f5", size = 86916218, upload-time = "2026-05-20T02:37:41.886Z" }, + { url = 
"https://files.pythonhosted.org/packages/6b/e6/bb8634bd4f3adaea299c95d7b03105ac417e32dd6d8bc2af5dda141d6f28/openai_codex_cli_bin-0.132.0-py3-none-win_arm64.whl", hash = "sha256:74ef93d3deef7cb83c71d19fc667defe749cdab337ec331f59a23511561b6f6a", size = 79892931, upload-time = "2026-05-20T02:37:46.828Z" }, ] [[package]] From 729adaf207b1145cd234f98abcf107db4a624f3b Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 18:58:16 -0700 Subject: [PATCH 11/15] codex: address retry review structure feedback (#25147) --- codex-rs/rmcp-client/src/rmcp_client.rs | 403 +---------------- .../src/rmcp_client/streamable_http_retry.rs | 411 ++++++++++++++++++ .../streamable_http_retry_tests.rs | 93 ++++ codex-rs/rmcp-client/src/rmcp_client_tests.rs | 86 ---- justfile | 16 +- 5 files changed, 516 insertions(+), 493 deletions(-) create mode 100644 codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs create mode 100644 codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index c909c88bdaf..b2f764b9cae 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -14,7 +14,6 @@ use anyhow::anyhow; use codex_api::SharedAuthProvider; use codex_client::maybe_build_rustls_client_config_with_custom_ca; use codex_config::types::McpServerEnvVar; -use codex_exec_server::ExecServerError; use codex_exec_server::HttpClient; use futures::FutureExt; use futures::future::BoxFuture; @@ -50,7 +49,6 @@ use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::AuthError; use rmcp::transport::auth::OAuthState; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; -use rmcp::transport::streamable_http_client::StreamableHttpError; use serde::Deserialize; use serde::Serialize; use serde_json::Value; @@ -62,7 +60,6 @@ use tracing::warn; use crate::elicitation_client_service::ElicitationClientService; use 
crate::http_client_adapter::StreamableHttpClientAdapter; -use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::in_process_transport::InProcessTransportFactory; use crate::load_oauth_tokens; use crate::oauth::OAuthPersistor; @@ -75,9 +72,9 @@ use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use codex_config::types::OAuthCredentialsStoreMode; -const MCP_CLIENT_INITIALIZE_METRIC: &str = "codex.mcp.client.initialize"; -const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; -const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; +mod streamable_http_retry; + +use self::streamable_http_retry::HandshakeError; enum PendingTransport { InProcess { @@ -95,63 +92,6 @@ enum PendingTransport { }, } -impl PendingTransport { - fn is_streamable_http(&self) -> bool { - matches!( - self, - PendingTransport::StreamableHttp { .. } - | PendingTransport::StreamableHttpWithOAuth { .. } - ) - } - - fn metric_transport(&self) -> &'static str { - match self { - PendingTransport::InProcess { .. } => "in_process", - PendingTransport::Stdio { .. } => "stdio", - PendingTransport::StreamableHttp { .. } - | PendingTransport::StreamableHttpWithOAuth { .. 
} => "streamable_http", - } - } -} - -fn initialize_metric_tags( - transport: &'static str, - outcome: &'static str, - attempts: usize, - retry_exhausted: bool, - failure_kind: &'static str, -) -> Vec<(&'static str, String)> { - let attempts = attempts.max(1); - vec![ - ("transport", transport.to_string()), - ("outcome", outcome.to_string()), - ("retried", (attempts > 1).to_string()), - ("attempts", attempts.to_string()), - ("retry_count", attempts.saturating_sub(1).to_string()), - ("retry_exhausted", retry_exhausted.to_string()), - ("failure_kind", failure_kind.to_string()), - ] -} - -fn emit_initialize_metric( - transport: &'static str, - outcome: &'static str, - attempts: usize, - retry_exhausted: bool, - failure_kind: &'static str, -) { - let Some(metrics) = codex_otel::global() else { - return; - }; - - let tags = initialize_metric_tags(transport, outcome, attempts, retry_exhausted, failure_kind); - let tag_refs: Vec<(&str, &str)> = tags - .iter() - .map(|(key, value)| (*key, value.as_str())) - .collect(); - let _ = metrics.counter(MCP_CLIENT_INITIALIZE_METRIC, /*inc*/ 1, &tag_refs); -} - enum ClientState { Connecting { transport: Option, @@ -277,34 +217,6 @@ where } } -async fn sleep_with_retry_deadline(delay: Duration, deadline: Option) -> bool { - if let Some(deadline) = deadline { - let remaining = deadline.saturating_duration_since(Instant::now()); - if remaining.is_zero() { - return false; - } - time::timeout(remaining, time::sleep(delay)).await.is_ok() - } else { - time::sleep(delay).await; - true - } -} - -#[derive(Debug, thiserror::Error)] -enum ClientOperationError { - #[error(transparent)] - Service(#[from] rmcp::service::ServiceError), - #[error("timed out awaiting {label} after {duration:?}")] - Timeout { label: String, duration: Duration }, -} - -#[derive(Debug, thiserror::Error)] -#[error("handshaking with MCP server failed: {source}")] -struct HandshakeError { - #[source] - source: rmcp::service::ClientInitializeError, -} - pub type Elicitation 
= CreateElicitationRequestParams; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -944,315 +856,6 @@ impl RmcpClient { Ok((Arc::new(service), oauth_persistor)) } - async fn connect_pending_transport_with_initialize_retries( - &self, - initial_transport: PendingTransport, - client_service: ElicitationClientService, - timeout: Option, - ) -> Result<( - Arc>, - Option, - )> { - let should_retry = initial_transport.is_streamable_http(); - let metric_transport = initial_transport.metric_transport(); - let retry_deadline = timeout.map(|duration| Instant::now() + duration); - let mut pending_transport = Some(initial_transport); - - let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS - .iter() - .copied() - .map(Some) - .chain(std::iter::once(None)); - - for (attempt, retry_delay_ms) in retry_schedule.enumerate() { - let attempt_count = attempt + 1; - let attempt_timeout = - retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); - if let Some(remaining) = attempt_timeout - && remaining.is_zero() - { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "timeout", - ); - let duration = timeout.unwrap_or(remaining); - return Err(anyhow!( - "timed out handshaking with MCP server after {duration:?}" - )); - } - - let transport = match pending_transport.take() { - Some(transport) => transport, - None => match Self::create_pending_transport(&self.transport_recipe).await { - Ok(transport) => transport, - Err(error) => { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "transport_create", - ); - return Err(error); - } - }, - }; - - match Self::connect_pending_transport( - transport, - client_service.clone(), - attempt_timeout, - ) - .await - { - Ok(result) => { - emit_initialize_metric( - metric_transport, - "success", - attempt_count, - /*retry_exhausted*/ false, - "none", - ); - return Ok(result); - } - Err(error) if should_retry 
&& Self::is_retryable_initialize_error(&error) => { - let Some(retry_delay_ms) = retry_delay_ms else { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ true, - "retry_exhausted", - ); - return Err(error); - }; - let delay = Duration::from_millis(retry_delay_ms); - warn!( - attempt = attempt_count, - max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, - delay_ms = delay.as_millis(), - error = %error, - "streamable HTTP MCP initialize failed with a retryable error; retrying" - ); - if !sleep_with_retry_deadline(delay, retry_deadline).await { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "timeout", - ); - let duration = timeout.unwrap_or(delay); - return Err(anyhow!( - "timed out handshaking with MCP server after {duration:?}" - )); - } - } - Err(error) => { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "non_retryable", - ); - return Err(error); - } - } - } - - unreachable!("initialize retry loop should return on success or final error") - } - - async fn run_service_operation( - &self, - label: &str, - timeout: Option, - operation: F, - ) -> Result - where - F: Fn(Arc>) -> Fut, - Fut: std::future::Future>, - { - let mut session_recovery_attempted = false; - let mut retry_attempt = 0; - let retry_deadline = timeout.map(|duration| Instant::now() + duration); - - loop { - let service = self.service().await?; - let attempt_timeout = - retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); - if let Some(remaining) = attempt_timeout - && remaining.is_zero() - { - let duration = timeout.unwrap_or(remaining); - return Err(ClientOperationError::Timeout { - label: label.to_string(), - duration, - } - .into()); - } - - match Self::run_service_operation_once( - Arc::clone(&service), - label, - attempt_timeout, - self.elicitation_pause_state.clone(), - &operation, - ) - .await - { - 
Ok(result) => return Ok(result), - Err(error) - if !session_recovery_attempted && Self::is_session_expired_404(&error) => - { - session_recovery_attempted = true; - self.reinitialize_after_session_expiry(&service).await?; - } - Err(error) - if Self::should_retry_tools_list_operation(label, retry_attempt, &error) => - { - let delay = - Duration::from_millis(STREAMABLE_HTTP_RETRY_DELAYS_MS[retry_attempt]); - retry_attempt += 1; - warn!( - label, - attempt = retry_attempt, - max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, - delay_ms = delay.as_millis(), - error = %error, - "MCP service operation failed with a retryable error; retrying" - ); - if !sleep_with_retry_deadline(delay, retry_deadline).await { - let duration = timeout.unwrap_or(delay); - return Err(ClientOperationError::Timeout { - label: label.to_string(), - duration, - } - .into()); - } - } - Err(error) => return Err(error.into()), - } - } - } - - async fn run_service_operation_once( - service: Arc>, - label: &str, - timeout: Option, - pause_state: ElicitationPauseState, - operation: &F, - ) -> std::result::Result - where - F: Fn(Arc>) -> Fut, - Fut: std::future::Future>, - { - match timeout { - Some(duration) => { - active_time_timeout(duration, pause_state.subscribe(), operation(service)) - .await - .map_err(|_| ClientOperationError::Timeout { - label: label.to_string(), - duration, - })? 
- .map_err(ClientOperationError::from) - } - None => operation(service).await.map_err(ClientOperationError::from), - } - } - - fn is_session_expired_404(error: &ClientOperationError) -> bool { - let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = - error - else { - return false; - }; - - error - .error - .downcast_ref::>() - .is_some_and(|error| { - matches!( - error, - StreamableHttpError::Client( - StreamableHttpClientAdapterError::SessionExpired404 - ) - ) - }) - } - - fn should_retry_tools_list_operation( - label: &str, - retry_attempt: usize, - error: &ClientOperationError, - ) -> bool { - label == "tools/list" - && retry_attempt < STREAMABLE_HTTP_RETRY_DELAYS_MS.len() - && Self::is_retryable_service_operation_error(error) - } - - fn is_retryable_service_operation_error(error: &ClientOperationError) -> bool { - let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = - error - else { - return false; - }; - - error - .error - .downcast_ref::>() - .is_some_and(Self::is_retryable_streamable_http_error) - } - - fn is_retryable_initialize_error(error: &anyhow::Error) -> bool { - error.chain().any(|source| { - source - .downcast_ref::() - .is_some_and(|error| Self::is_retryable_client_initialize_error(&error.source)) - || source - .downcast_ref::() - .is_some_and(Self::is_retryable_client_initialize_error) - }) - } - - fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool { - match error { - rmcp::service::ClientInitializeError::TransportError { error, context } - if matches!( - context.as_ref(), - "send initialize request" | "send initialized notification" - ) => - { - error - .error - .downcast_ref::>() - .is_some_and(Self::is_retryable_streamable_http_error) - } - _ => false, - } - } - - fn is_retryable_streamable_http_error( - error: &StreamableHttpError, - ) -> bool { - match error { - StreamableHttpError::Client( - 
StreamableHttpClientAdapterError::RetryableHttpStatus(_) - | StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest(_)), - ) => true, - StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( - ExecServerError::Server { code, message }, - )) => { - *code == JSON_RPC_INTERNAL_ERROR_CODE && message.starts_with("http/request failed:") - } - _ => false, - } - } - async fn reinitialize_after_session_expiry( &self, failed_service: &Arc>, diff --git a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs new file mode 100644 index 00000000000..42cd225d0a9 --- /dev/null +++ b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs @@ -0,0 +1,411 @@ +use std::future::Future; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use anyhow::Result; +use anyhow::anyhow; +use codex_exec_server::ExecServerError; +use rmcp::service::RoleClient; +use rmcp::service::RunningService; +use rmcp::transport::streamable_http_client::StreamableHttpError; +use tokio::time; +use tracing::warn; + +use crate::elicitation_client_service::ElicitationClientService; +use crate::http_client_adapter::StreamableHttpClientAdapterError; +use crate::oauth::OAuthPersistor; + +use super::ElicitationPauseState; +use super::PendingTransport; +use super::RmcpClient; +use super::active_time_timeout; + +const MCP_CLIENT_INITIALIZE_METRIC: &str = "codex.mcp.client.initialize"; +const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; +const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; + +impl RmcpClient { + pub(super) async fn connect_pending_transport_with_initialize_retries( + &self, + initial_transport: PendingTransport, + client_service: ElicitationClientService, + timeout: Option, + ) -> Result<( + Arc>, + Option, + )> { + let (should_retry, metric_transport) = match &initial_transport { + PendingTransport::InProcess { .. 
} => (false, "in_process"), + PendingTransport::Stdio { .. } => (false, "stdio"), + PendingTransport::StreamableHttp { .. } + | PendingTransport::StreamableHttpWithOAuth { .. } => (true, "streamable_http"), + }; + let retry_deadline = timeout.map(|duration| Instant::now() + duration); + let mut pending_transport = Some(initial_transport); + + let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS + .iter() + .copied() + .map(Some) + .chain(std::iter::once(None)); + + for (attempt, retry_delay_ms) in retry_schedule.enumerate() { + let attempt_count = attempt + 1; + let attempt_timeout = + retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); + if let Some(remaining) = attempt_timeout + && remaining.is_zero() + { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "timeout", + ); + let duration = timeout.unwrap_or(remaining); + return Err(anyhow!( + "timed out handshaking with MCP server after {duration:?}" + )); + } + + let transport = match pending_transport.take() { + Some(transport) => transport, + None => match Self::create_pending_transport(&self.transport_recipe).await { + Ok(transport) => transport, + Err(error) => { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "transport_create", + ); + return Err(error); + } + }, + }; + + match Self::connect_pending_transport( + transport, + client_service.clone(), + attempt_timeout, + ) + .await + { + Ok(result) => { + emit_initialize_metric( + metric_transport, + "success", + attempt_count, + /*retry_exhausted*/ false, + "none", + ); + return Ok(result); + } + Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => { + let Some(retry_delay_ms) = retry_delay_ms else { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ true, + "retry_exhausted", + ); + return Err(error); + }; + let delay = 
Duration::from_millis(retry_delay_ms); + warn!( + attempt = attempt_count, + max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, + delay_ms = delay.as_millis(), + error = %error, + "streamable HTTP MCP initialize failed with a retryable error; retrying" + ); + if !sleep_with_retry_deadline(delay, retry_deadline).await { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "timeout", + ); + let duration = timeout.unwrap_or(delay); + return Err(anyhow!( + "timed out handshaking with MCP server after {duration:?}" + )); + } + } + Err(error) => { + emit_initialize_metric( + metric_transport, + "error", + attempt_count, + /*retry_exhausted*/ false, + "non_retryable", + ); + return Err(error); + } + } + } + + unreachable!("initialize retry loop should return on success or final error") + } + + pub(super) async fn run_service_operation( + &self, + label: &str, + timeout: Option, + operation: F, + ) -> Result + where + F: Fn(Arc>) -> Fut, + Fut: Future>, + { + let mut session_recovery_attempted = false; + let mut retry_attempt = 0; + let retry_deadline = timeout.map(|duration| Instant::now() + duration); + + loop { + let service = self.service().await?; + let attempt_timeout = + retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); + if let Some(remaining) = attempt_timeout + && remaining.is_zero() + { + let duration = timeout.unwrap_or(remaining); + return Err(ClientOperationError::Timeout { + label: label.to_string(), + duration, + } + .into()); + } + + match Self::run_service_operation_once( + Arc::clone(&service), + label, + attempt_timeout, + self.elicitation_pause_state.clone(), + &operation, + ) + .await + { + Ok(result) => return Ok(result), + Err(error) + if !session_recovery_attempted && Self::is_session_expired_404(&error) => + { + session_recovery_attempted = true; + self.reinitialize_after_session_expiry(&service).await?; + } + Err(error) + if 
Self::should_retry_tools_list_operation(label, retry_attempt, &error) => + { + let delay = + Duration::from_millis(STREAMABLE_HTTP_RETRY_DELAYS_MS[retry_attempt]); + retry_attempt += 1; + warn!( + label, + attempt = retry_attempt, + max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, + delay_ms = delay.as_millis(), + error = %error, + "MCP service operation failed with a retryable error; retrying" + ); + if !sleep_with_retry_deadline(delay, retry_deadline).await { + let duration = timeout.unwrap_or(delay); + return Err(ClientOperationError::Timeout { + label: label.to_string(), + duration, + } + .into()); + } + } + Err(error) => return Err(error.into()), + } + } + } + + async fn run_service_operation_once( + service: Arc>, + label: &str, + timeout: Option, + pause_state: ElicitationPauseState, + operation: &F, + ) -> std::result::Result + where + F: Fn(Arc>) -> Fut, + Fut: Future>, + { + match timeout { + Some(duration) => { + active_time_timeout(duration, pause_state.subscribe(), operation(service)) + .await + .map_err(|_| ClientOperationError::Timeout { + label: label.to_string(), + duration, + })? 
+ .map_err(ClientOperationError::from) + } + None => operation(service).await.map_err(ClientOperationError::from), + } + } + + fn is_session_expired_404(error: &ClientOperationError) -> bool { + let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = + error + else { + return false; + }; + + error + .error + .downcast_ref::>() + .is_some_and(|error| { + matches!( + error, + StreamableHttpError::Client( + StreamableHttpClientAdapterError::SessionExpired404 + ) + ) + }) + } + + fn should_retry_tools_list_operation( + label: &str, + retry_attempt: usize, + error: &ClientOperationError, + ) -> bool { + label == "tools/list" + && retry_attempt < STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + && Self::is_retryable_service_operation_error(error) + } + + fn is_retryable_service_operation_error(error: &ClientOperationError) -> bool { + let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = + error + else { + return false; + }; + + error + .error + .downcast_ref::>() + .is_some_and(Self::is_retryable_streamable_http_error) + } + + fn is_retryable_initialize_error(error: &anyhow::Error) -> bool { + error.chain().any(|source| { + source + .downcast_ref::() + .is_some_and(|error| Self::is_retryable_client_initialize_error(&error.source)) + || source + .downcast_ref::() + .is_some_and(Self::is_retryable_client_initialize_error) + }) + } + + fn is_retryable_client_initialize_error(error: &rmcp::service::ClientInitializeError) -> bool { + match error { + rmcp::service::ClientInitializeError::TransportError { error, context } + if matches!( + context.as_ref(), + "send initialize request" | "send initialized notification" + ) => + { + error + .error + .downcast_ref::>() + .is_some_and(Self::is_retryable_streamable_http_error) + } + _ => false, + } + } + + fn is_retryable_streamable_http_error( + error: &StreamableHttpError, + ) -> bool { + match error { + StreamableHttpError::Client( + 
StreamableHttpClientAdapterError::RetryableHttpStatus(_) + | StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest(_)), + ) => true, + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Server { code, message }, + )) => { + *code == JSON_RPC_INTERNAL_ERROR_CODE && message.starts_with("http/request failed:") + } + _ => false, + } + } +} + +async fn sleep_with_retry_deadline(delay: Duration, deadline: Option) -> bool { + if let Some(deadline) = deadline { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return false; + } + time::timeout(remaining, time::sleep(delay)).await.is_ok() + } else { + time::sleep(delay).await; + true + } +} + +fn initialize_metric_tags( + transport: &'static str, + outcome: &'static str, + attempts: usize, + retry_exhausted: bool, + failure_kind: &'static str, +) -> Vec<(&'static str, String)> { + let attempts = attempts.max(1); + vec![ + ("transport", transport.to_string()), + ("outcome", outcome.to_string()), + ("retried", (attempts > 1).to_string()), + ("attempts", attempts.to_string()), + ("retry_count", attempts.saturating_sub(1).to_string()), + ("retry_exhausted", retry_exhausted.to_string()), + ("failure_kind", failure_kind.to_string()), + ] +} + +fn emit_initialize_metric( + transport: &'static str, + outcome: &'static str, + attempts: usize, + retry_exhausted: bool, + failure_kind: &'static str, +) { + let Some(metrics) = codex_otel::global() else { + return; + }; + + let tags = initialize_metric_tags(transport, outcome, attempts, retry_exhausted, failure_kind); + let tag_refs: Vec<(&str, &str)> = tags + .iter() + .map(|(key, value)| (*key, value.as_str())) + .collect(); + let _ = metrics.counter(MCP_CLIENT_INITIALIZE_METRIC, /*inc*/ 1, &tag_refs); +} + +#[derive(Debug, thiserror::Error)] +enum ClientOperationError { + #[error(transparent)] + Service(#[from] rmcp::service::ServiceError), + #[error("timed out awaiting 
{label} after {duration:?}")] + Timeout { label: String, duration: Duration }, +} + +#[derive(Debug, thiserror::Error)] +#[error("handshaking with MCP server failed: {source}")] +pub(super) struct HandshakeError { + #[source] + pub(super) source: rmcp::service::ClientInitializeError, +} + +#[cfg(test)] +#[path = "streamable_http_retry_tests.rs"] +mod tests; diff --git a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs new file mode 100644 index 00000000000..9a58c27c0d9 --- /dev/null +++ b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs @@ -0,0 +1,93 @@ +use std::any::TypeId; +use std::collections::BTreeMap; + +use pretty_assertions::assert_eq; +use rmcp::transport::DynamicTransportError; +use rmcp::transport::streamable_http_client::StreamableHttpError; + +use crate::http_client_adapter::StreamableHttpClientAdapterError; + +use super::*; + +fn metric_tags_map(tags: Vec<(&'static str, String)>) -> BTreeMap<&'static str, String> { + tags.into_iter().collect() +} + +#[test] +fn initialize_metric_tags_record_success_after_retry() { + let tags = metric_tags_map(initialize_metric_tags( + "streamable_http", + "success", + /*attempts*/ 2, + /*retry_exhausted*/ false, + "none", + )); + + assert_eq!( + tags, + metric_tags_map(vec![ + ("transport", "streamable_http".to_string()), + ("outcome", "success".to_string()), + ("retried", "true".to_string()), + ("attempts", "2".to_string()), + ("retry_count", "1".to_string()), + ("retry_exhausted", "false".to_string()), + ("failure_kind", "none".to_string()), + ]) + ); +} + +#[test] +fn initialize_metric_tags_record_retry_exhaustion() { + let tags = metric_tags_map(initialize_metric_tags( + "streamable_http", + "error", + /*attempts*/ 3, + /*retry_exhausted*/ true, + "retry_exhausted", + )); + + assert_eq!( + tags, + metric_tags_map(vec![ + ("transport", "streamable_http".to_string()), + ("outcome", "error".to_string()), + 
("retried", "true".to_string()), + ("attempts", "3".to_string()), + ("retry_count", "2".to_string()), + ("retry_exhausted", "true".to_string()), + ("failure_kind", "retry_exhausted".to_string()), + ]) + ); +} + +#[test] +fn retryable_initialize_error_includes_initialized_notification_context() { + let contexts = [ + "send initialize request", + "send initialized notification", + "receive initialize response", + ]; + + assert_eq!( + contexts.map(|context| { + RmcpClient::is_retryable_client_initialize_error(&retryable_initialize_error(context)) + }), + [true, true, false], + ); +} + +fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientInitializeError { + rmcp::service::ClientInitializeError::TransportError { + error: DynamicTransportError::from_parts( + "streamable_http", + TypeId::of::<()>(), + Box::new(StreamableHttpError::Client( + StreamableHttpClientAdapterError::RetryableHttpStatus( + reqwest::StatusCode::SERVICE_UNAVAILABLE.as_u16(), + ), + )), + ), + context: context.into(), + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client_tests.rs b/codex-rs/rmcp-client/src/rmcp_client_tests.rs index f8689a73391..f371acde0de 100644 --- a/codex-rs/rmcp-client/src/rmcp_client_tests.rs +++ b/codex-rs/rmcp-client/src/rmcp_client_tests.rs @@ -1,17 +1,10 @@ -use std::any::TypeId; -use std::collections::BTreeMap; use std::time::Duration; use pretty_assertions::assert_eq; -use rmcp::transport::DynamicTransportError; use tokio::time; use super::*; -fn metric_tags_map(tags: Vec<(&'static str, String)>) -> BTreeMap<&'static str, String> { - tags.into_iter().collect() -} - #[tokio::test] async fn active_time_timeout_pauses_while_elicitation_is_pending() { let pause_state = ElicitationPauseState::new(); @@ -29,82 +22,3 @@ async fn active_time_timeout_pauses_while_elicitation_is_pending() { assert_eq!(Ok("done"), result); } - -#[test] -fn initialize_metric_tags_record_success_after_retry() { - let tags = metric_tags_map(initialize_metric_tags( - 
"streamable_http", - "success", - /*attempts*/ 2, - /*retry_exhausted*/ false, - "none", - )); - - assert_eq!( - tags, - metric_tags_map(vec![ - ("transport", "streamable_http".to_string()), - ("outcome", "success".to_string()), - ("retried", "true".to_string()), - ("attempts", "2".to_string()), - ("retry_count", "1".to_string()), - ("retry_exhausted", "false".to_string()), - ("failure_kind", "none".to_string()), - ]) - ); -} - -#[test] -fn initialize_metric_tags_record_retry_exhaustion() { - let tags = metric_tags_map(initialize_metric_tags( - "streamable_http", - "error", - /*attempts*/ 3, - /*retry_exhausted*/ true, - "retry_exhausted", - )); - - assert_eq!( - tags, - metric_tags_map(vec![ - ("transport", "streamable_http".to_string()), - ("outcome", "error".to_string()), - ("retried", "true".to_string()), - ("attempts", "3".to_string()), - ("retry_count", "2".to_string()), - ("retry_exhausted", "true".to_string()), - ("failure_kind", "retry_exhausted".to_string()), - ]) - ); -} - -#[test] -fn retryable_initialize_error_includes_initialized_notification_context() { - let contexts = [ - "send initialize request", - "send initialized notification", - "receive initialize response", - ]; - - assert_eq!( - contexts.map(|context| { - RmcpClient::is_retryable_client_initialize_error(&retryable_initialize_error(context)) - }), - [true, true, false], - ); -} - -fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientInitializeError { - rmcp::service::ClientInitializeError::TransportError { - error: DynamicTransportError::from_parts( - "streamable_http", - TypeId::of::<()>(), - Box::new(StreamableHttpError::Client( - StreamableHttpClientAdapterError::RetryableHttpStatus( - reqwest::StatusCode::SERVICE_UNAVAILABLE.as_u16(), - ), - )), - ), - context: context.into(), - } -} diff --git a/justfile b/justfile index 852af2aae80..55388472104 100644 --- a/justfile +++ b/justfile @@ -1,10 +1,12 @@ set working-directory := "codex-rs" -set positional-arguments 
+set positional-arguments := true + export JUST_SHELL := justfile_directory() / "scripts/just-shell.py" + set shell := ["python3", "-c", 'import os, runpy; runpy.run_path(os.environ["JUST_SHELL"], run_name="__main__")'] set windows-shell := ["python", "-c", 'import os, runpy; runpy.run_path(os.environ["JUST_SHELL"], run_name="__main__")'] -rust_min_stack := "8388608" # 8 MiB +rust_min_stack := "8388608" python := if os_family() == "windows" { "python" } else { "python3" } # Display help @@ -12,7 +14,9 @@ help: just -l # `codex` + alias c := codex + codex *args: cargo run --bin codex -- {args} @@ -44,11 +48,6 @@ fmt: fmt-check: {{ python }} ../scripts/format.py --check -fmt-check: - cargo fmt -- --check --config imports_granularity=Item 2>/dev/null - uv run --frozen --project ../sdk/python --extra dev ruff check --diff ../sdk/python - uv run --frozen --project ../sdk/python --extra dev ruff format --check ../sdk/python - fix *args: cargo clippy --fix --tests --allow-dirty {args} @@ -77,6 +76,7 @@ install: # # Run `cargo install --locked cargo-nextest` if you don't have it installed. # Prefer this for routine local runs. Workspace crate features are banned, so + # there should be no need to add `--all-features`. [unix] test *args: @@ -89,6 +89,7 @@ test *args: just bench-smoke # Run from the repository root so scripts that resolve paths from `cwd` see + # the same layout they use in GitHub Actions. [no-cd] test-github-scripts: @@ -104,6 +105,7 @@ bench-smoke: # Build and run Codex from source using Bazel. # On Unix, use `[no-cd]` and `--run_under="cd $PWD &&"` to ensure Bazel runs + # the command in the current working directory. 
[no-cd] [unix] From 78265d6f1057a370c3d4f691c7ec28c7c62fc7cd Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 19:03:37 -0700 Subject: [PATCH 12/15] codex: fix justfile formatting for CI (#25147) --- justfile | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/justfile b/justfile index 55388472104..fe7e7349b31 100644 --- a/justfile +++ b/justfile @@ -1,12 +1,10 @@ set working-directory := "codex-rs" -set positional-arguments := true - +set positional-arguments export JUST_SHELL := justfile_directory() / "scripts/just-shell.py" - set shell := ["python3", "-c", 'import os, runpy; runpy.run_path(os.environ["JUST_SHELL"], run_name="__main__")'] set windows-shell := ["python", "-c", 'import os, runpy; runpy.run_path(os.environ["JUST_SHELL"], run_name="__main__")'] -rust_min_stack := "8388608" +rust_min_stack := "8388608" # 8 MiB python := if os_family() == "windows" { "python" } else { "python3" } # Display help @@ -14,9 +12,7 @@ help: just -l # `codex` - alias c := codex - codex *args: cargo run --bin codex -- {args} @@ -76,7 +72,6 @@ install: # # Run `cargo install --locked cargo-nextest` if you don't have it installed. # Prefer this for routine local runs. Workspace crate features are banned, so - # there should be no need to add `--all-features`. [unix] test *args: @@ -89,7 +84,6 @@ test *args: just bench-smoke # Run from the repository root so scripts that resolve paths from `cwd` see - # the same layout they use in GitHub Actions. [no-cd] test-github-scripts: @@ -105,7 +99,6 @@ bench-smoke: # Build and run Codex from source using Bazel. # On Unix, use `[no-cd]` and `--run_under="cd $PWD &&"` to ensure Bazel runs - # the command in the current working directory. 
[no-cd] [unix] From 3c3374efdf078d521705cee8d4d579677f28a7c3 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 19:16:56 -0700 Subject: [PATCH 13/15] codex: address streamable http retry review feedback (#25147) --- codex-rs/Cargo.lock | 1 - codex-rs/rmcp-client/Cargo.toml | 1 - .../rmcp-client/src/http_client_adapter.rs | 133 +---------------- .../src/http_client_adapter_tests.rs | 136 ++++++++++++++++++ .../src/rmcp_client/streamable_http_retry.rs | 103 ++----------- .../streamable_http_retry_tests.rs | 76 +++------- 6 files changed, 170 insertions(+), 280 deletions(-) create mode 100644 codex-rs/rmcp-client/src/http_client_adapter_tests.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ae9083c1844..1e79865d437 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -3573,7 +3573,6 @@ dependencies = [ "codex-config", "codex-exec-server", "codex-keyring-store", - "codex-otel", "codex-protocol", "codex-utils-cargo-bin", "codex-utils-home-dir", diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 152ac34de43..e3417de70ec 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -20,7 +20,6 @@ codex-client = { workspace = true } codex-config = { workspace = true } codex-exec-server = { workspace = true } codex-keyring-store = { workspace = true } -codex-otel = { workspace = true } codex-protocol = { workspace = true } codex-utils-pty = { workspace = true } codex-utils-home-dir = { workspace = true } diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index df7bee3a83b..16f4ba7064c 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -513,134 +513,5 @@ fn sse_stream_from_body( } #[cfg(test)] -mod tests { - use axum::Json; - use axum::Router; - use axum::http::StatusCode; - use axum::response::IntoResponse; - use axum::routing::post; - use 
codex_exec_server::Environment; - use pretty_assertions::assert_eq; - use rmcp::model::ClientRequest; - use rmcp::model::ErrorData; - use rmcp::model::JsonRpcError; - use rmcp::model::PingRequest; - use rmcp::model::RequestId; - use rmcp::model::ServerJsonRpcMessage; - use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; - use serde_json::json; - use tokio::net::TcpListener; - - use super::*; - - #[tokio::test] - async fn post_message_parses_json_error_body_before_retryable_status() -> anyhow::Result<()> { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let address = listener.local_addr()?; - let app = Router::new().route("/", post(json_error_response)); - let server = tokio::spawn(async move { axum::serve(listener, app).await }); - - let adapter = StreamableHttpClientAdapter::new( - Environment::default_for_tests().get_http_client(), - HeaderMap::new(), - /*auth_provider*/ None, - ); - let request = ClientJsonRpcMessage::request( - ClientRequest::PingRequest(PingRequest::default()), - RequestId::Number(1), - ); - - let response = adapter - .post_message( - Arc::from(format!("http://{address}/")), - request, - /*session_id*/ None, - /*auth_token*/ None, - HashMap::new(), - ) - .await?; - - server.abort(); - - let StreamableHttpPostResponse::Json(message, _session_id) = response else { - panic!("expected JSON response"); - }; - let ServerJsonRpcMessage::Error(error) = message else { - panic!("expected JSON-RPC error"); - }; - assert_eq!( - error, - JsonRpcError::new( - /*id*/ Some(RequestId::Number(1)), - ErrorData::internal_error("transient json error", /*data*/ None), - ) - ); - - Ok(()) - } - - #[tokio::test] - async fn post_message_retries_non_json_rpc_json_error_body() -> anyhow::Result<()> { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let address = listener.local_addr()?; - let app = Router::new().route("/", post(non_json_rpc_json_error_response)); - let server = tokio::spawn(async move { axum::serve(listener, 
app).await }); - - let adapter = StreamableHttpClientAdapter::new( - Environment::default_for_tests().get_http_client(), - HeaderMap::new(), - /*auth_provider*/ None, - ); - let request = ClientJsonRpcMessage::request( - ClientRequest::PingRequest(PingRequest::default()), - RequestId::Number(1), - ); - - let result = adapter - .post_message( - Arc::from(format!("http://{address}/")), - request, - /*session_id*/ None, - /*auth_token*/ None, - HashMap::new(), - ) - .await; - - server.abort(); - - let Err(StreamableHttpError::Client( - StreamableHttpClientAdapterError::RetryableHttpStatus(status), - )) = result - else { - panic!("expected retryable HTTP status error"); - }; - assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE.as_u16()); - - Ok(()) - } - - async fn json_error_response() -> impl IntoResponse { - ( - StatusCode::INTERNAL_SERVER_ERROR, - [(CONTENT_TYPE, JSON_MIME_TYPE)], - Json(json!({ - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32603, - "message": "transient json error", - }, - })), - ) - } - - async fn non_json_rpc_json_error_response() -> impl IntoResponse { - ( - StatusCode::SERVICE_UNAVAILABLE, - [(CONTENT_TYPE, JSON_MIME_TYPE)], - Json(json!({ - "error": "service temporarily unavailable", - })), - ) - } -} +#[path = "http_client_adapter_tests.rs"] +mod tests; diff --git a/codex-rs/rmcp-client/src/http_client_adapter_tests.rs b/codex-rs/rmcp-client/src/http_client_adapter_tests.rs new file mode 100644 index 00000000000..28e686f0c2b --- /dev/null +++ b/codex-rs/rmcp-client/src/http_client_adapter_tests.rs @@ -0,0 +1,136 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use axum::Json; +use axum::Router; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::routing::post; +use codex_exec_server::Environment; +use pretty_assertions::assert_eq; +use reqwest::header::HeaderMap; +use rmcp::model::ClientJsonRpcMessage; +use rmcp::model::ClientRequest; +use rmcp::model::ErrorData; +use rmcp::model::JsonRpcError; +use 
rmcp::model::PingRequest; +use rmcp::model::RequestId; +use rmcp::model::ServerJsonRpcMessage; +use rmcp::transport::streamable_http_client::StreamableHttpClient; +use rmcp::transport::streamable_http_client::StreamableHttpError; +use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; +use serde_json::json; +use tokio::net::TcpListener; + +use super::*; + +#[tokio::test] +async fn post_message_parses_json_error_body_before_retryable_status() -> anyhow::Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let address = listener.local_addr()?; + let app = Router::new().route("/", post(json_error_response)); + let server = tokio::spawn(async move { axum::serve(listener, app).await }); + + let adapter = StreamableHttpClientAdapter::new( + Environment::default_for_tests().get_http_client(), + HeaderMap::new(), + /*auth_provider*/ None, + ); + let request = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let response = adapter + .post_message( + Arc::from(format!("http://{address}/")), + request, + /*session_id*/ None, + /*auth_token*/ None, + HashMap::new(), + ) + .await?; + + server.abort(); + + let StreamableHttpPostResponse::Json(message, _session_id) = response else { + panic!("expected JSON response"); + }; + let ServerJsonRpcMessage::Error(error) = message else { + panic!("expected JSON-RPC error"); + }; + assert_eq!( + error, + JsonRpcError::new( + /*id*/ Some(RequestId::Number(1)), + ErrorData::internal_error("transient json error", /*data*/ None), + ) + ); + + Ok(()) +} + +#[tokio::test] +async fn post_message_retries_non_json_rpc_json_error_body() -> anyhow::Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let address = listener.local_addr()?; + let app = Router::new().route("/", post(non_json_rpc_json_error_response)); + let server = tokio::spawn(async move { axum::serve(listener, app).await }); + + let adapter = 
StreamableHttpClientAdapter::new( + Environment::default_for_tests().get_http_client(), + HeaderMap::new(), + /*auth_provider*/ None, + ); + let request = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let result = adapter + .post_message( + Arc::from(format!("http://{address}/")), + request, + /*session_id*/ None, + /*auth_token*/ None, + HashMap::new(), + ) + .await; + + server.abort(); + + let Err(StreamableHttpError::Client(StreamableHttpClientAdapterError::RetryableHttpStatus( + status, + ))) = result + else { + panic!("expected retryable HTTP status error"); + }; + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE.as_u16()); + + Ok(()) +} + +async fn json_error_response() -> impl IntoResponse { + ( + StatusCode::INTERNAL_SERVER_ERROR, + [(CONTENT_TYPE, JSON_MIME_TYPE)], + Json(json!({ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32603, + "message": "transient json error", + }, + })), + ) +} + +async fn non_json_rpc_json_error_response() -> impl IntoResponse { + ( + StatusCode::SERVICE_UNAVAILABLE, + [(CONTENT_TYPE, JSON_MIME_TYPE)], + Json(json!({ + "error": "service temporarily unavailable", + })), + ) +} diff --git a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs index 42cd225d0a9..b1f3b363891 100644 --- a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs +++ b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs @@ -21,7 +21,6 @@ use super::PendingTransport; use super::RmcpClient; use super::active_time_timeout; -const MCP_CLIENT_INITIALIZE_METRIC: &str = "codex.mcp.client.initialize"; const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; @@ -35,11 +34,10 @@ impl RmcpClient { Arc>, Option, )> { - let (should_retry, metric_transport) = match &initial_transport { - PendingTransport::InProcess { .. 
} => (false, "in_process"), - PendingTransport::Stdio { .. } => (false, "stdio"), + let should_retry = match &initial_transport { + PendingTransport::InProcess { .. } | PendingTransport::Stdio { .. } => false, PendingTransport::StreamableHttp { .. } - | PendingTransport::StreamableHttpWithOAuth { .. } => (true, "streamable_http"), + | PendingTransport::StreamableHttpWithOAuth { .. } => true, }; let retry_deadline = timeout.map(|duration| Instant::now() + duration); let mut pending_transport = Some(initial_transport); @@ -57,13 +55,6 @@ impl RmcpClient { if let Some(remaining) = attempt_timeout && remaining.is_zero() { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "timeout", - ); let duration = timeout.unwrap_or(remaining); return Err(anyhow!( "timed out handshaking with MCP server after {duration:?}" @@ -74,16 +65,7 @@ impl RmcpClient { Some(transport) => transport, None => match Self::create_pending_transport(&self.transport_recipe).await { Ok(transport) => transport, - Err(error) => { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "transport_create", - ); - return Err(error); - } + Err(error) => return Err(error), }, }; @@ -94,25 +76,9 @@ impl RmcpClient { ) .await { - Ok(result) => { - emit_initialize_metric( - metric_transport, - "success", - attempt_count, - /*retry_exhausted*/ false, - "none", - ); - return Ok(result); - } + Ok(result) => return Ok(result), Err(error) if should_retry && Self::is_retryable_initialize_error(&error) => { let Some(retry_delay_ms) = retry_delay_ms else { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ true, - "retry_exhausted", - ); return Err(error); }; let delay = Duration::from_millis(retry_delay_ms); @@ -124,29 +90,13 @@ impl RmcpClient { "streamable HTTP MCP initialize failed with a retryable error; retrying" ); if !sleep_with_retry_deadline(delay, 
retry_deadline).await { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "timeout", - ); let duration = timeout.unwrap_or(delay); return Err(anyhow!( "timed out handshaking with MCP server after {duration:?}" )); } } - Err(error) => { - emit_initialize_metric( - metric_transport, - "error", - attempt_count, - /*retry_exhausted*/ false, - "non_retryable", - ); - return Err(error); - } + Err(error) => return Err(error), } } @@ -335,6 +285,9 @@ impl RmcpClient { )) => { *code == JSON_RPC_INTERNAL_ERROR_CODE && message.starts_with("http/request failed:") } + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Protocol(message), + )) => message.starts_with("http response stream `") && message.contains("` failed:"), _ => false, } } @@ -353,44 +306,6 @@ async fn sleep_with_retry_deadline(delay: Duration, deadline: Option) - } } -fn initialize_metric_tags( - transport: &'static str, - outcome: &'static str, - attempts: usize, - retry_exhausted: bool, - failure_kind: &'static str, -) -> Vec<(&'static str, String)> { - let attempts = attempts.max(1); - vec![ - ("transport", transport.to_string()), - ("outcome", outcome.to_string()), - ("retried", (attempts > 1).to_string()), - ("attempts", attempts.to_string()), - ("retry_count", attempts.saturating_sub(1).to_string()), - ("retry_exhausted", retry_exhausted.to_string()), - ("failure_kind", failure_kind.to_string()), - ] -} - -fn emit_initialize_metric( - transport: &'static str, - outcome: &'static str, - attempts: usize, - retry_exhausted: bool, - failure_kind: &'static str, -) { - let Some(metrics) = codex_otel::global() else { - return; - }; - - let tags = initialize_metric_tags(transport, outcome, attempts, retry_exhausted, failure_kind); - let tag_refs: Vec<(&str, &str)> = tags - .iter() - .map(|(key, value)| (*key, value.as_str())) - .collect(); - let _ = metrics.counter(MCP_CLIENT_INITIALIZE_METRIC, /*inc*/ 1, &tag_refs); 
-} - #[derive(Debug, thiserror::Error)] enum ClientOperationError { #[error(transparent)] diff --git a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs index 9a58c27c0d9..17db426062a 100644 --- a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs +++ b/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs @@ -1,6 +1,6 @@ use std::any::TypeId; -use std::collections::BTreeMap; +use codex_exec_server::ExecServerError; use pretty_assertions::assert_eq; use rmcp::transport::DynamicTransportError; use rmcp::transport::streamable_http_client::StreamableHttpError; @@ -9,58 +9,6 @@ use crate::http_client_adapter::StreamableHttpClientAdapterError; use super::*; -fn metric_tags_map(tags: Vec<(&'static str, String)>) -> BTreeMap<&'static str, String> { - tags.into_iter().collect() -} - -#[test] -fn initialize_metric_tags_record_success_after_retry() { - let tags = metric_tags_map(initialize_metric_tags( - "streamable_http", - "success", - /*attempts*/ 2, - /*retry_exhausted*/ false, - "none", - )); - - assert_eq!( - tags, - metric_tags_map(vec![ - ("transport", "streamable_http".to_string()), - ("outcome", "success".to_string()), - ("retried", "true".to_string()), - ("attempts", "2".to_string()), - ("retry_count", "1".to_string()), - ("retry_exhausted", "false".to_string()), - ("failure_kind", "none".to_string()), - ]) - ); -} - -#[test] -fn initialize_metric_tags_record_retry_exhaustion() { - let tags = metric_tags_map(initialize_metric_tags( - "streamable_http", - "error", - /*attempts*/ 3, - /*retry_exhausted*/ true, - "retry_exhausted", - )); - - assert_eq!( - tags, - metric_tags_map(vec![ - ("transport", "streamable_http".to_string()), - ("outcome", "error".to_string()), - ("retried", "true".to_string()), - ("attempts", "3".to_string()), - ("retry_count", "2".to_string()), - ("retry_exhausted", "true".to_string()), - ("failure_kind", 
"retry_exhausted".to_string()), - ]) - ); -} - #[test] fn retryable_initialize_error_includes_initialized_notification_context() { let contexts = [ @@ -77,6 +25,28 @@ fn retryable_initialize_error_includes_initialized_notification_context() { ); } +#[test] +fn retryable_streamable_http_error_includes_remote_body_stream_failure() { + let errors = [ + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Protocol( + "http response stream `http-1` failed: exec-server transport disconnected" + .to_string(), + ), + )), + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Protocol( + "http response stream `http-1` received seq 2, expected 1".to_string(), + ), + )), + ]; + + assert_eq!( + errors.map(|error| RmcpClient::is_retryable_streamable_http_error(&error)), + [true, false], + ); +} + fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientInitializeError { rmcp::service::ClientInitializeError::TransportError { error: DynamicTransportError::from_parts( From 2dd78920b08c096471fca1b5d311f86a9df0d03d Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 19:22:37 -0700 Subject: [PATCH 14/15] codex: keep rmcp retry module surgical (#25147) --- codex-rs/rmcp-client/src/rmcp_client.rs | 1 + .../rmcp-client/src/{rmcp_client => }/streamable_http_retry.rs | 0 .../src/{rmcp_client => }/streamable_http_retry_tests.rs | 0 3 files changed, 1 insertion(+) rename codex-rs/rmcp-client/src/{rmcp_client => }/streamable_http_retry.rs (100%) rename codex-rs/rmcp-client/src/{rmcp_client => }/streamable_http_retry_tests.rs (100%) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index b2f764b9cae..74ce9baac0e 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -72,6 +72,7 @@ use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use 
codex_config::types::OAuthCredentialsStoreMode; +#[path = "streamable_http_retry.rs"] mod streamable_http_retry; use self::streamable_http_retry::HandshakeError; diff --git a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs b/codex-rs/rmcp-client/src/streamable_http_retry.rs similarity index 100% rename from codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry.rs rename to codex-rs/rmcp-client/src/streamable_http_retry.rs diff --git a/codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs b/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs similarity index 100% rename from codex-rs/rmcp-client/src/rmcp_client/streamable_http_retry_tests.rs rename to codex-rs/rmcp-client/src/streamable_http_retry_tests.rs From e76c6e4c933663edc133133f42ec6b8886e4b122 Mon Sep 17 00:00:00 2001 From: Sama Setty Date: Thu, 4 Jun 2026 19:37:57 -0700 Subject: [PATCH 15/15] codex: narrow streamable http initialize retry (#25147) --- .../src/bin/test_streamable_http_server.rs | 106 +++--- .../rmcp-client/src/http_client_adapter.rs | 39 +-- .../src/http_client_adapter_tests.rs | 136 -------- codex-rs/rmcp-client/src/rmcp_client.rs | 122 ++++++- codex-rs/rmcp-client/src/rmcp_client_tests.rs | 24 -- .../rmcp-client/src/streamable_http_retry.rs | 175 +--------- .../src/streamable_http_retry_tests.rs | 17 +- .../tests/streamable_http_recovery.rs | 308 ++---------------- .../tests/streamable_http_remote.rs | 248 -------------- .../tests/streamable_http_test_support.rs | 21 +- 10 files changed, 222 insertions(+), 974 deletions(-) delete mode 100644 codex-rs/rmcp-client/src/http_client_adapter_tests.rs delete mode 100644 codex-rs/rmcp-client/src/rmcp_client_tests.rs diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs index 8004aa5dfba..2384d394736 100644 --- a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs +++ 
b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -63,13 +63,11 @@ struct TestToolServer { const MEMO_URI: &str = "memo://codex/example-note"; const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server."; const MCP_SESSION_ID_HEADER: &str = "mcp-session-id"; -const INITIALIZE_FAILURE_CONTROL_PATH: &str = "/test/control/initialize-failure"; const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; #[derive(Clone, Default)] -struct FailureState { - initialize_failure: Arc>>, - session_post_failure: Arc>>, +struct SessionFailureState { + armed_failure: Arc>>, } #[derive(Clone, Debug)] @@ -81,7 +79,7 @@ struct ArmedFailure { } #[derive(Debug, Deserialize)] -struct ArmFailureRequest { +struct ArmSessionPostFailureRequest { status: u16, remaining: usize, /// Raw `WWW-Authenticate` challenge header field values to add to the failure. @@ -99,7 +97,7 @@ struct EchoArgs { #[tokio::main] async fn main() -> Result<(), Box> { let bind_addr = parse_bind_addr()?; - let failure_state = FailureState::default(); + let session_failure_state = SessionFailureState::default(); const MAX_BIND_RETRIES: u32 = 20; const BIND_RETRY_DELAY: Duration = Duration::from_millis(50); @@ -127,7 +125,6 @@ async fn main() -> Result<(), Box> { eprintln!("starting rmcp streamable http test server on http://{actual_bind_addr}/mcp"); let router = Router::new() - .route(INITIALIZE_FAILURE_CONTROL_PATH, post(arm_initialize_failure)) .route( SESSION_POST_FAILURE_CONTROL_PATH, post(arm_session_post_failure), @@ -165,10 +162,10 @@ async fn main() -> Result<(), Box> { ), ) .layer(middleware::from_fn_with_state( - failure_state.clone(), - fail_post_when_armed, + session_failure_state.clone(), + fail_session_post_when_armed, )) - .with_state(failure_state); + .with_state(session_failure_state); let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") { let expected = Arc::new(format!("Bearer {token}")); @@ -407,22 +404,8 @@ async fn 
require_bearer( } async fn arm_session_post_failure( - State(state): State, - Json(request): Json, -) -> Result { - arm_failure(&state.session_post_failure, request).await -} - -async fn arm_initialize_failure( - State(state): State, - Json(request): Json, -) -> Result { - arm_failure(&state.initialize_failure, request).await -} - -async fn arm_failure( - armed_failure: &Arc>>, - request: ArmFailureRequest, + State(state): State, + Json(request): Json, ) -> Result { let status = StatusCode::from_u16(request.status).map_err(|_| StatusCode::BAD_REQUEST)?; let www_authenticate_headers = request @@ -430,7 +413,7 @@ async fn arm_failure( .into_iter() .map(|value| HeaderValue::from_str(&value).map_err(|_| StatusCode::BAD_REQUEST)) .collect::, _>>()?; - let failure = if request.remaining == 0 { + let armed_failure = if request.remaining == 0 { None } else { Some(ArmedFailure { @@ -439,56 +422,45 @@ async fn arm_failure( www_authenticate_headers, }) }; - *armed_failure.lock().await = failure; + *state.armed_failure.lock().await = armed_failure; Ok(StatusCode::NO_CONTENT) } -async fn fail_post_when_armed( - State(state): State, +async fn fail_session_post_when_armed( + State(state): State, request: Request, next: Next, ) -> Response { - if request.uri().path() != "/mcp" || request.method() != Method::POST { + if request.uri().path() != "/mcp" + || request.method() != Method::POST + || !request.headers().contains_key(MCP_SESSION_ID_HEADER) + { return next.run(request).await; } - let (armed_failure, label) = if request.headers().contains_key(MCP_SESSION_ID_HEADER) { - (&state.session_post_failure, "session") - } else { - (&state.initialize_failure, "initialize") - }; - - if let Some(response) = consume_failure(armed_failure, label).await { - return response; + { + let mut armed_failure = state.armed_failure.lock().await; + if let Some(failure) = armed_failure.as_mut() + && failure.remaining > 0 + { + failure.remaining -= 1; + let status = failure.status; + let 
www_authenticate_headers = failure.www_authenticate_headers.clone(); + if failure.remaining == 0 { + *armed_failure = None; + } + let mut response = Response::new(Body::from(format!( + "forced session failure with status {status}" + ))); + *response.status_mut() = status; + for www_authenticate_header in www_authenticate_headers { + response + .headers_mut() + .append(WWW_AUTHENTICATE, www_authenticate_header); + } + return response; + } } next.run(request).await } - -async fn consume_failure( - armed_failure: &Arc>>, - label: &str, -) -> Option { - let mut armed_failure = armed_failure.lock().await; - let failure = armed_failure.as_mut()?; - if failure.remaining == 0 { - return None; - } - - failure.remaining -= 1; - let status = failure.status; - let www_authenticate_headers = failure.www_authenticate_headers.clone(); - if failure.remaining == 0 { - *armed_failure = None; - } - let mut response = Response::new(Body::from(format!( - "forced {label} failure with status {status}" - ))); - *response.status_mut() = status; - for www_authenticate_header in www_authenticate_headers { - response - .headers_mut() - .append(WWW_AUTHENTICATE, www_authenticate_header); - } - Some(response) -} diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs index 16f4ba7064c..6f98f789205 100644 --- a/codex-rs/rmcp-client/src/http_client_adapter.rs +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -29,6 +29,7 @@ use reqwest::header::HeaderMap; use reqwest::header::HeaderName; use rmcp::model::ClientJsonRpcMessage; use rmcp::model::JsonRpcMessage; +use rmcp::model::ServerJsonRpcMessage; use rmcp::transport::streamable_http_client::AuthRequiredError; use rmcp::transport::streamable_http_client::InsufficientScopeError; use rmcp::transport::streamable_http_client::StreamableHttpClient; @@ -57,8 +58,6 @@ pub(crate) struct StreamableHttpClientAdapter { pub(crate) enum StreamableHttpClientAdapterError { #[error("streamable HTTP 
session expired with 404 Not Found")] SessionExpired404, - #[error("streamable HTTP request returned retryable HTTP {0}")] - RetryableHttpStatus(u16), #[error(transparent)] HttpRequest(#[from] ExecServerError), #[error("invalid HTTP header: {0}")] @@ -186,33 +185,17 @@ impl StreamableHttpClient for StreamableHttpClientAdapter { let content_type = response_header(&response.headers, CONTENT_TYPE); let session_id = response_header(&response.headers, HEADER_SESSION_ID); - if let Some(content_type) = content_type.as_deref() - && content_type.starts_with(JSON_MIME_TYPE) - { - let body = collect_body(&mut body_stream).await?; - let message = match serde_json::from_slice(&body) { - Ok(message) => message, - Err(_error) if is_retryable_http_status(response.status) => { - return Err(StreamableHttpError::Client( - StreamableHttpClientAdapterError::RetryableHttpStatus(response.status), - )); - } - Err(error) => return Err(StreamableHttpError::Deserialize(error)), - }; - return Ok(StreamableHttpPostResponse::Json(message, session_id)); - } - - if is_retryable_http_status(response.status) { - return Err(StreamableHttpError::Client( - StreamableHttpClientAdapterError::RetryableHttpStatus(response.status), - )); - } - match content_type.as_deref() { Some(content_type) if content_type.starts_with(EVENT_STREAM_MIME_TYPE) => { let event_stream = sse_stream_from_body(body_stream); Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) } + Some(content_type) if content_type.starts_with(JSON_MIME_TYPE) => { + let body = collect_body(&mut body_stream).await?; + let message: ServerJsonRpcMessage = + serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?; + Ok(StreamableHttpPostResponse::Json(message, session_id)) + } _ => { let body = collect_body(&mut body_stream).await?; let content_type = content_type.unwrap_or_else(|| "missing-content-type".into()); @@ -480,10 +463,6 @@ fn status_is_success(status: u16) -> bool { StatusCode::from_u16(status).is_ok_and(|status| 
status.is_success()) } -fn is_retryable_http_status(status: u16) -> bool { - matches!(status, 408 | 429 | 500 | 502 | 503 | 504) -} - async fn collect_body( body_stream: &mut HttpResponseBodyStream, ) -> std::result::Result, StreamableHttpError> { @@ -511,7 +490,3 @@ fn sse_stream_from_body( })) .boxed() } - -#[cfg(test)] -#[path = "http_client_adapter_tests.rs"] -mod tests; diff --git a/codex-rs/rmcp-client/src/http_client_adapter_tests.rs b/codex-rs/rmcp-client/src/http_client_adapter_tests.rs deleted file mode 100644 index 28e686f0c2b..00000000000 --- a/codex-rs/rmcp-client/src/http_client_adapter_tests.rs +++ /dev/null @@ -1,136 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use axum::Json; -use axum::Router; -use axum::http::StatusCode; -use axum::response::IntoResponse; -use axum::routing::post; -use codex_exec_server::Environment; -use pretty_assertions::assert_eq; -use reqwest::header::HeaderMap; -use rmcp::model::ClientJsonRpcMessage; -use rmcp::model::ClientRequest; -use rmcp::model::ErrorData; -use rmcp::model::JsonRpcError; -use rmcp::model::PingRequest; -use rmcp::model::RequestId; -use rmcp::model::ServerJsonRpcMessage; -use rmcp::transport::streamable_http_client::StreamableHttpClient; -use rmcp::transport::streamable_http_client::StreamableHttpError; -use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; -use serde_json::json; -use tokio::net::TcpListener; - -use super::*; - -#[tokio::test] -async fn post_message_parses_json_error_body_before_retryable_status() -> anyhow::Result<()> { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let address = listener.local_addr()?; - let app = Router::new().route("/", post(json_error_response)); - let server = tokio::spawn(async move { axum::serve(listener, app).await }); - - let adapter = StreamableHttpClientAdapter::new( - Environment::default_for_tests().get_http_client(), - HeaderMap::new(), - /*auth_provider*/ None, - ); - let request = 
ClientJsonRpcMessage::request( - ClientRequest::PingRequest(PingRequest::default()), - RequestId::Number(1), - ); - - let response = adapter - .post_message( - Arc::from(format!("http://{address}/")), - request, - /*session_id*/ None, - /*auth_token*/ None, - HashMap::new(), - ) - .await?; - - server.abort(); - - let StreamableHttpPostResponse::Json(message, _session_id) = response else { - panic!("expected JSON response"); - }; - let ServerJsonRpcMessage::Error(error) = message else { - panic!("expected JSON-RPC error"); - }; - assert_eq!( - error, - JsonRpcError::new( - /*id*/ Some(RequestId::Number(1)), - ErrorData::internal_error("transient json error", /*data*/ None), - ) - ); - - Ok(()) -} - -#[tokio::test] -async fn post_message_retries_non_json_rpc_json_error_body() -> anyhow::Result<()> { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let address = listener.local_addr()?; - let app = Router::new().route("/", post(non_json_rpc_json_error_response)); - let server = tokio::spawn(async move { axum::serve(listener, app).await }); - - let adapter = StreamableHttpClientAdapter::new( - Environment::default_for_tests().get_http_client(), - HeaderMap::new(), - /*auth_provider*/ None, - ); - let request = ClientJsonRpcMessage::request( - ClientRequest::PingRequest(PingRequest::default()), - RequestId::Number(1), - ); - - let result = adapter - .post_message( - Arc::from(format!("http://{address}/")), - request, - /*session_id*/ None, - /*auth_token*/ None, - HashMap::new(), - ) - .await; - - server.abort(); - - let Err(StreamableHttpError::Client(StreamableHttpClientAdapterError::RetryableHttpStatus( - status, - ))) = result - else { - panic!("expected retryable HTTP status error"); - }; - assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE.as_u16()); - - Ok(()) -} - -async fn json_error_response() -> impl IntoResponse { - ( - StatusCode::INTERNAL_SERVER_ERROR, - [(CONTENT_TYPE, JSON_MIME_TYPE)], - Json(json!({ - "jsonrpc": "2.0", - "id": 1, - "error": { 
- "code": -32603, - "message": "transient json error", - }, - })), - ) -} - -async fn non_json_rpc_json_error_response() -> impl IntoResponse { - ( - StatusCode::SERVICE_UNAVAILABLE, - [(CONTENT_TYPE, JSON_MIME_TYPE)], - Json(json!({ - "error": "service temporarily unavailable", - })), - ) -} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 74ce9baac0e..753b27a3f06 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -49,6 +49,7 @@ use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::AuthError; use rmcp::transport::auth::OAuthState; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +use rmcp::transport::streamable_http_client::StreamableHttpError; use serde::Deserialize; use serde::Serialize; use serde_json::Value; @@ -60,6 +61,7 @@ use tracing::warn; use crate::elicitation_client_service::ElicitationClientService; use crate::http_client_adapter::StreamableHttpClientAdapter; +use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::in_process_transport::InProcessTransportFactory; use crate::load_oauth_tokens; use crate::oauth::OAuthPersistor; @@ -218,6 +220,14 @@ where } } +#[derive(Debug, thiserror::Error)] +enum ClientOperationError { + #[error(transparent)] + Service(#[from] rmcp::service::ServiceError), + #[error("timed out awaiting {label} after {duration:?}")] + Timeout { label: String, duration: Duration }, +} + pub type Elicitation = CreateElicitationRequestParams; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -857,6 +867,89 @@ impl RmcpClient { Ok((Arc::new(service), oauth_persistor)) } + async fn run_service_operation( + &self, + label: &str, + timeout: Option, + operation: F, + ) -> Result + where + F: Fn(Arc>) -> Fut, + Fut: std::future::Future>, + { + let service = self.service().await?; + match Self::run_service_operation_once( + Arc::clone(&service), + label, + timeout, + 
self.elicitation_pause_state.clone(), + &operation, + ) + .await + { + Ok(result) => Ok(result), + Err(error) if Self::is_session_expired_404(&error) => { + self.reinitialize_after_session_expiry(&service).await?; + let recovered_service = self.service().await?; + Self::run_service_operation_once( + recovered_service, + label, + timeout, + self.elicitation_pause_state.clone(), + &operation, + ) + .await + .map_err(Into::into) + } + Err(error) => Err(error.into()), + } + } + + async fn run_service_operation_once( + service: Arc>, + label: &str, + timeout: Option, + pause_state: ElicitationPauseState, + operation: &F, + ) -> std::result::Result + where + F: Fn(Arc>) -> Fut, + Fut: std::future::Future>, + { + match timeout { + Some(duration) => { + active_time_timeout(duration, pause_state.subscribe(), operation(service)) + .await + .map_err(|_| ClientOperationError::Timeout { + label: label.to_string(), + duration, + })? + .map_err(ClientOperationError::from) + } + None => operation(service).await.map_err(ClientOperationError::from), + } + } + + fn is_session_expired_404(error: &ClientOperationError) -> bool { + let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = + error + else { + return false; + }; + + error + .error + .downcast_ref::>() + .is_some_and(|error| { + matches!( + error, + StreamableHttpError::Client( + StreamableHttpClientAdapterError::SessionExpired404 + ) + ) + }) + } + async fn reinitialize_after_session_expiry( &self, failed_service: &Arc>, @@ -978,5 +1071,30 @@ async fn create_oauth_transport_and_runtime( } #[cfg(test)] -#[path = "rmcp_client_tests.rs"] -mod tests; +mod tests { + use std::time::Duration; + + use pretty_assertions::assert_eq; + use tokio::time; + + use super::*; + + #[tokio::test] + async fn active_time_timeout_pauses_while_elicitation_is_pending() { + let pause_state = ElicitationPauseState::new(); + let pause = pause_state.enter(); + tokio::spawn(async move { + 
time::sleep(Duration::from_millis(75)).await; + drop(pause); + }); + + let result = + active_time_timeout(Duration::from_millis(50), pause_state.subscribe(), async { + time::sleep(Duration::from_millis(90)).await; + "done" + }) + .await; + + assert_eq!(Ok("done"), result); + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client_tests.rs b/codex-rs/rmcp-client/src/rmcp_client_tests.rs deleted file mode 100644 index f371acde0de..00000000000 --- a/codex-rs/rmcp-client/src/rmcp_client_tests.rs +++ /dev/null @@ -1,24 +0,0 @@ -use std::time::Duration; - -use pretty_assertions::assert_eq; -use tokio::time; - -use super::*; - -#[tokio::test] -async fn active_time_timeout_pauses_while_elicitation_is_pending() { - let pause_state = ElicitationPauseState::new(); - let pause = pause_state.enter(); - tokio::spawn(async move { - time::sleep(Duration::from_millis(75)).await; - drop(pause); - }); - - let result = active_time_timeout(Duration::from_millis(50), pause_state.subscribe(), async { - time::sleep(Duration::from_millis(90)).await; - "done" - }) - .await; - - assert_eq!(Ok("done"), result); -} diff --git a/codex-rs/rmcp-client/src/streamable_http_retry.rs b/codex-rs/rmcp-client/src/streamable_http_retry.rs index b1f3b363891..ebca075cb8a 100644 --- a/codex-rs/rmcp-client/src/streamable_http_retry.rs +++ b/codex-rs/rmcp-client/src/streamable_http_retry.rs @@ -1,4 +1,3 @@ -use std::future::Future; use std::sync::Arc; use std::time::Duration; use std::time::Instant; @@ -16,10 +15,8 @@ use crate::elicitation_client_service::ElicitationClientService; use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::oauth::OAuthPersistor; -use super::ElicitationPauseState; use super::PendingTransport; use super::RmcpClient; -use super::active_time_timeout; const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; const STREAMABLE_HTTP_RETRY_DELAYS_MS: [u64; 2] = [250, 1_000]; @@ -42,14 +39,13 @@ impl RmcpClient { let retry_deadline = timeout.map(|duration| Instant::now() + 
duration); let mut pending_transport = Some(initial_transport); - let retry_schedule = STREAMABLE_HTTP_RETRY_DELAYS_MS + for (attempt, retry_delay_ms) in STREAMABLE_HTTP_RETRY_DELAYS_MS .iter() .copied() .map(Some) - .chain(std::iter::once(None)); - - for (attempt, retry_delay_ms) in retry_schedule.enumerate() { - let attempt_count = attempt + 1; + .chain(std::iter::once(None)) + .enumerate() + { let attempt_timeout = retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); if let Some(remaining) = attempt_timeout @@ -63,10 +59,7 @@ impl RmcpClient { let transport = match pending_transport.take() { Some(transport) => transport, - None => match Self::create_pending_transport(&self.transport_recipe).await { - Ok(transport) => transport, - Err(error) => return Err(error), - }, + None => Self::create_pending_transport(&self.transport_recipe).await?, }; match Self::connect_pending_transport( @@ -83,7 +76,7 @@ impl RmcpClient { }; let delay = Duration::from_millis(retry_delay_ms); warn!( - attempt = attempt_count, + attempt = attempt + 1, max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, delay_ms = delay.as_millis(), error = %error, @@ -103,147 +96,6 @@ impl RmcpClient { unreachable!("initialize retry loop should return on success or final error") } - pub(super) async fn run_service_operation( - &self, - label: &str, - timeout: Option, - operation: F, - ) -> Result - where - F: Fn(Arc>) -> Fut, - Fut: Future>, - { - let mut session_recovery_attempted = false; - let mut retry_attempt = 0; - let retry_deadline = timeout.map(|duration| Instant::now() + duration); - - loop { - let service = self.service().await?; - let attempt_timeout = - retry_deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); - if let Some(remaining) = attempt_timeout - && remaining.is_zero() - { - let duration = timeout.unwrap_or(remaining); - return Err(ClientOperationError::Timeout { - label: label.to_string(), - duration, - } - .into()); - } - 
- match Self::run_service_operation_once( - Arc::clone(&service), - label, - attempt_timeout, - self.elicitation_pause_state.clone(), - &operation, - ) - .await - { - Ok(result) => return Ok(result), - Err(error) - if !session_recovery_attempted && Self::is_session_expired_404(&error) => - { - session_recovery_attempted = true; - self.reinitialize_after_session_expiry(&service).await?; - } - Err(error) - if Self::should_retry_tools_list_operation(label, retry_attempt, &error) => - { - let delay = - Duration::from_millis(STREAMABLE_HTTP_RETRY_DELAYS_MS[retry_attempt]); - retry_attempt += 1; - warn!( - label, - attempt = retry_attempt, - max_attempts = STREAMABLE_HTTP_RETRY_DELAYS_MS.len() + 1, - delay_ms = delay.as_millis(), - error = %error, - "MCP service operation failed with a retryable error; retrying" - ); - if !sleep_with_retry_deadline(delay, retry_deadline).await { - let duration = timeout.unwrap_or(delay); - return Err(ClientOperationError::Timeout { - label: label.to_string(), - duration, - } - .into()); - } - } - Err(error) => return Err(error.into()), - } - } - } - - async fn run_service_operation_once( - service: Arc>, - label: &str, - timeout: Option, - pause_state: ElicitationPauseState, - operation: &F, - ) -> std::result::Result - where - F: Fn(Arc>) -> Fut, - Fut: Future>, - { - match timeout { - Some(duration) => { - active_time_timeout(duration, pause_state.subscribe(), operation(service)) - .await - .map_err(|_| ClientOperationError::Timeout { - label: label.to_string(), - duration, - })? 
- .map_err(ClientOperationError::from) - } - None => operation(service).await.map_err(ClientOperationError::from), - } - } - - fn is_session_expired_404(error: &ClientOperationError) -> bool { - let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = - error - else { - return false; - }; - - error - .error - .downcast_ref::>() - .is_some_and(|error| { - matches!( - error, - StreamableHttpError::Client( - StreamableHttpClientAdapterError::SessionExpired404 - ) - ) - }) - } - - fn should_retry_tools_list_operation( - label: &str, - retry_attempt: usize, - error: &ClientOperationError, - ) -> bool { - label == "tools/list" - && retry_attempt < STREAMABLE_HTTP_RETRY_DELAYS_MS.len() - && Self::is_retryable_service_operation_error(error) - } - - fn is_retryable_service_operation_error(error: &ClientOperationError) -> bool { - let ClientOperationError::Service(rmcp::service::ServiceError::TransportSend(error)) = - error - else { - return false; - }; - - error - .error - .downcast_ref::>() - .is_some_and(Self::is_retryable_streamable_http_error) - } - fn is_retryable_initialize_error(error: &anyhow::Error) -> bool { error.chain().any(|source| { source @@ -276,10 +128,9 @@ impl RmcpClient { error: &StreamableHttpError, ) -> bool { match error { - StreamableHttpError::Client( - StreamableHttpClientAdapterError::RetryableHttpStatus(_) - | StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest(_)), - ) => true, + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::HttpRequest(_), + )) => true, StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( ExecServerError::Server { code, message }, )) => { @@ -306,14 +157,6 @@ async fn sleep_with_retry_deadline(delay: Duration, deadline: Option) - } } -#[derive(Debug, thiserror::Error)] -enum ClientOperationError { - #[error(transparent)] - Service(#[from] rmcp::service::ServiceError), - #[error("timed out awaiting {label} 
after {duration:?}")] - Timeout { label: String, duration: Duration }, -} - #[derive(Debug, thiserror::Error)] #[error("handshaking with MCP server failed: {source}")] pub(super) struct HandshakeError { diff --git a/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs b/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs index 17db426062a..8d2e6ecabf1 100644 --- a/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs +++ b/codex-rs/rmcp-client/src/streamable_http_retry_tests.rs @@ -28,6 +28,15 @@ fn retryable_initialize_error_includes_initialized_notification_context() { #[test] fn retryable_streamable_http_error_includes_remote_body_stream_failure() { let errors = [ + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::HttpRequest("error sending request for url".to_string()), + )), + StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( + ExecServerError::Server { + code: JSON_RPC_INTERNAL_ERROR_CODE, + message: "http/request failed: error sending request for url".to_string(), + }, + )), StreamableHttpError::Client(StreamableHttpClientAdapterError::HttpRequest( ExecServerError::Protocol( "http response stream `http-1` failed: exec-server transport disconnected" @@ -43,7 +52,7 @@ fn retryable_streamable_http_error_includes_remote_body_stream_failure() { assert_eq!( errors.map(|error| RmcpClient::is_retryable_streamable_http_error(&error)), - [true, false], + [true, true, true, false], ); } @@ -53,9 +62,9 @@ fn retryable_initialize_error(context: &'static str) -> rmcp::service::ClientIni "streamable_http", TypeId::of::<()>(), Box::new(StreamableHttpError::Client( - StreamableHttpClientAdapterError::RetryableHttpStatus( - reqwest::StatusCode::SERVICE_UNAVAILABLE.as_u16(), - ), + StreamableHttpClientAdapterError::HttpRequest(ExecServerError::HttpRequest( + "error sending request for url".to_string(), + )), )), ), context: context.into(), diff --git 
a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 265efa671db..571d9adb8df 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -3,70 +3,51 @@ mod streamable_http_test_support; use std::sync::Arc; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; -use std::time::Duration; -use std::time::Instant; -use codex_config::types::OAuthCredentialsStoreMode; use codex_exec_server::Environment; use codex_exec_server::ExecServerError; use codex_exec_server::HttpClient; use codex_exec_server::HttpRequestParams; use codex_exec_server::HttpRequestResponse; use codex_exec_server::HttpResponseBodyStream; -use codex_rmcp_client::ElicitationAction; -use codex_rmcp_client::ElicitationResponse; -use codex_rmcp_client::RmcpClient; use futures::FutureExt as _; use futures::future::BoxFuture; use pretty_assertions::assert_eq; use serde_json::Value; -use serde_json::json; -use streamable_http_test_support::arm_initialize_failure; use streamable_http_test_support::arm_session_post_failure; use streamable_http_test_support::call_echo_tool; use streamable_http_test_support::create_client; use streamable_http_test_support::create_client_with_http_client; use streamable_http_test_support::expected_echo_result; -use streamable_http_test_support::init_params; use streamable_http_test_support::spawn_streamable_http_server; const JSON_RPC_INTERNAL_ERROR_CODE: i64 = -32603; const SIMULATED_NO_RESPONSE_MESSAGE: &str = "http/request failed: error sending request for url (simulated no response)"; -#[derive(Clone, Copy)] -enum RequestFailure { - LocalHttpRequest, - RemoteServer, -} - #[derive(Clone)] -struct FailFirstMethodHttpClient { +struct FailFirstInitializeHttpClient { inner: Arc, - method: &'static str, - failure: RequestFailure, failures_remaining: Arc, - matching_post_attempts: Arc, + initialize_attempts: Arc, } -impl 
FailFirstMethodHttpClient { - fn new(inner: Arc, method: &'static str, failure: RequestFailure) -> Self { +impl FailFirstInitializeHttpClient { + fn new(inner: Arc) -> Self { Self { inner, - method, - failure, failures_remaining: Arc::new(AtomicUsize::new(1)), - matching_post_attempts: Arc::new(AtomicUsize::new(0)), + initialize_attempts: Arc::new(AtomicUsize::new(0)), } } - fn matching_post_attempts(&self) -> usize { - self.matching_post_attempts.load(Ordering::SeqCst) + fn initialize_attempts(&self) -> usize { + self.initialize_attempts.load(Ordering::SeqCst) } } -impl HttpClient for FailFirstMethodHttpClient { +impl HttpClient for FailFirstInitializeHttpClient { fn http_request( &self, params: HttpRequestParams, @@ -79,28 +60,16 @@ impl HttpClient for FailFirstMethodHttpClient { params: HttpRequestParams, ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { let inner = Arc::clone(&self.inner); - let method = self.method; - let failure = self.failure; let failures_remaining = Arc::clone(&self.failures_remaining); - let matching_post_attempts = Arc::clone(&self.matching_post_attempts); + let initialize_attempts = Arc::clone(&self.initialize_attempts); async move { - if is_json_rpc_method(¶ms, method) { - matching_post_attempts.fetch_add(1, Ordering::SeqCst); - if failures_remaining - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| { - remaining.checked_sub(1) - }) - .is_ok() - { - return Err(match failure { - RequestFailure::LocalHttpRequest => { - ExecServerError::HttpRequest(SIMULATED_NO_RESPONSE_MESSAGE.to_string()) - } - RequestFailure::RemoteServer => ExecServerError::Server { - code: JSON_RPC_INTERNAL_ERROR_CODE, - message: SIMULATED_NO_RESPONSE_MESSAGE.to_string(), - }, + if is_initialize_post(¶ms) { + initialize_attempts.fetch_add(1, Ordering::SeqCst); + if failures_remaining.swap(0, Ordering::SeqCst) > 0 { + return Err(ExecServerError::Server { + code: JSON_RPC_INTERNAL_ERROR_CODE, + message: 
SIMULATED_NO_RESPONSE_MESSAGE.to_string(), }); } } @@ -111,242 +80,31 @@ impl HttpClient for FailFirstMethodHttpClient { } } -fn is_json_rpc_method(params: &HttpRequestParams, method: &str) -> bool { - if !params.method.eq_ignore_ascii_case("POST") { - return false; - } - - params - .body - .as_ref() - .and_then(|body| serde_json::from_slice::(&body.0).ok()) - .and_then(|body| { - body.get("method") - .and_then(Value::as_str) - .map(str::to_string) - }) - .is_some_and(|request_method| request_method == method) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_initialize_retries_retryable_status() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - - arm_initialize_failure(&base_url, /*status*/ 503, /*remaining*/ 1).await?; - - let client = create_client(&base_url).await?; - let result = call_echo_tool(&client, "after-init-retry").await?; - assert_eq!(result, expected_echo_result("after-init-retry")); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_initialize_retry_sleep_respects_startup_timeout() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - arm_initialize_failure(&base_url, /*status*/ 503, /*remaining*/ 1).await?; - - let client = RmcpClient::new_streamable_http_client( - "test-streamable-http", - &format!("{base_url}/mcp"), - Some("test-bearer".to_string()), - /*http_headers*/ None, - /*env_http_headers*/ None, - OAuthCredentialsStoreMode::File, - Environment::default_for_tests().get_http_client(), - /*auth_provider*/ None, - ) - .await?; - - let started = Instant::now(); - let error = client - .initialize( - init_params(), - Some(Duration::from_millis(100)), - Box::new(|_, _| { - async { - Ok(ElicitationResponse { - action: ElicitationAction::Accept, - content: Some(json!({})), - meta: None, - }) - } - .boxed() - }), - ) - .await - .unwrap_err(); - - let elapsed = 
started.elapsed(); - assert!( - elapsed < Duration::from_millis(500), - "initialize retry exceeded startup timeout budget: {elapsed:?}" - ); - assert!( - error - .to_string() - .contains("timed out handshaking with MCP server"), - "expected handshake timeout, got: {error:#}" - ); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_initialize_retries_http_request_error() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let http_client = FailFirstMethodHttpClient::new( - Environment::default_for_tests().get_http_client(), - "initialize", - RequestFailure::LocalHttpRequest, - ); - - let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; - let result = call_echo_tool(&client, "after-no-response-retry").await?; - - assert_eq!(http_client.matching_post_attempts(), 2); - assert_eq!(result, expected_echo_result("after-no-response-retry")); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_initialize_retries_remote_http_request_error() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let http_client = FailFirstMethodHttpClient::new( - Environment::default_for_tests().get_http_client(), - "initialize", - RequestFailure::RemoteServer, - ); - - let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; - let result = call_echo_tool(&client, "after-remote-no-response-retry").await?; - - assert_eq!(http_client.matching_post_attempts(), 2); - assert_eq!( - result, - expected_echo_result("after-remote-no-response-retry") - ); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_tools_list_retries_retryable_status() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let client = create_client(&base_url).await?; - let 
expected_tools = client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - - arm_session_post_failure( - &base_url, - /*status*/ 503, - /*remaining*/ 1, - /*www_authenticate_headers*/ &[], - ) - .await?; - - let tools = client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - - assert_eq!(tools, expected_tools); - - Ok(()) +fn is_initialize_post(params: &HttpRequestParams) -> bool { + params.method.eq_ignore_ascii_case("POST") + && params + .body + .as_ref() + .and_then(|body| serde_json::from_slice::(&body.0).ok()) + .and_then(|body| { + body.get("method") + .and_then(Value::as_str) + .map(|method| method == "initialize") + }) + .unwrap_or(false) } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_tools_list_retry_sleep_respects_operation_timeout() -> anyhow::Result<()> { +async fn streamable_http_initialize_retries_remote_no_response_error() -> anyhow::Result<()> { let (_server, base_url) = spawn_streamable_http_server().await?; - let client = create_client(&base_url).await?; + let http_client = + FailFirstInitializeHttpClient::new(Environment::default_for_tests().get_http_client()); - arm_session_post_failure( - &base_url, - /*status*/ 503, - /*remaining*/ 1, - /*www_authenticate_headers*/ &[], - ) - .await?; - - let started = Instant::now(); - let error = client - .list_tools(/*params*/ None, Some(Duration::from_millis(100))) - .await - .unwrap_err(); - - let elapsed = started.elapsed(); - assert!( - elapsed < Duration::from_millis(500), - "tools/list retry exceeded operation timeout budget: {elapsed:?}" - ); - assert!( - error.to_string().contains("timed out awaiting tools/list"), - "expected tools/list timeout, got: {error:#}" - ); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_tools_list_retries_http_request_error() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let 
baseline_client = create_client(&base_url).await?; - let expected_tools = baseline_client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - let http_client = FailFirstMethodHttpClient::new( - Environment::default_for_tests().get_http_client(), - "tools/list", - RequestFailure::LocalHttpRequest, - ); - let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; - - let tools = client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - - assert_eq!(http_client.matching_post_attempts(), 2); - assert_eq!(tools, expected_tools); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_tools_list_retries_remote_http_request_error() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let baseline_client = create_client(&base_url).await?; - let expected_tools = baseline_client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - let http_client = FailFirstMethodHttpClient::new( - Environment::default_for_tests().get_http_client(), - "tools/list", - RequestFailure::RemoteServer, - ); let client = create_client_with_http_client(&base_url, Arc::new(http_client.clone())).await?; + let result = call_echo_tool(&client, "after-init-retry").await?; - let tools = client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - - assert_eq!(http_client.matching_post_attempts(), 2); - assert_eq!(tools, expected_tools); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn streamable_http_initialize_does_not_retry_non_retryable_status() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - - arm_initialize_failure(&base_url, /*status*/ 403, /*remaining*/ 1).await?; - - let error = match create_client(&base_url).await { - Ok(_) => panic!("initialize unexpectedly succeeded after non-retryable HTTP 403"), - 
Err(error) => error, - }; - assert!(format!("{error:#}").contains("403")); + assert_eq!(http_client.initialize_attempts(), 2); + assert_eq!(result, expected_echo_result("after-init-retry")); Ok(()) } diff --git a/codex-rs/rmcp-client/tests/streamable_http_remote.rs b/codex-rs/rmcp-client/tests/streamable_http_remote.rs index 69f274910d7..0d4690a8255 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_remote.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_remote.rs @@ -6,19 +6,7 @@ mod streamable_http_test_support; -use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering; -use std::time::Duration; - -use anyhow::Context as _; use pretty_assertions::assert_eq; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::net::TcpListener; -use tokio::net::TcpStream; -use tokio::task::JoinHandle; use streamable_http_test_support::call_echo_tool; use streamable_http_test_support::create_remote_client; @@ -47,239 +35,3 @@ async fn streamable_http_remote_client_round_trips_through_exec_server() -> anyh Ok(()) } - -/// What this tests: when a real remote exec-server sees a no-status network -/// failure during the Streamable HTTP initialize request, it maps the reqwest -/// send failure into a JSON-RPC internal server error and the RMCP client still -/// treats that remote-shaped error as retryable. 
-#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn streamable_http_remote_initialize_retries_no_response_failure() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let proxy = DropNextMcpPostProxy::spawn(&base_url).await?; - proxy.arm_next_mcp_post_drop(); - let exec_server = spawn_exec_server().await?; - - let client = create_remote_client(proxy.base_url(), exec_server.client.clone()).await?; - let result = call_echo_tool(&client, "remote-init-retry").await?; - - assert_eq!(proxy.dropped_mcp_posts(), 1); - assert_eq!(result, expected_echo_result("remote-init-retry")); - - Ok(()) -} - -/// What this tests: once initialized through the real remote exec-server path, -/// a no-status Streamable HTTP failure during tools/list is retried instead of -/// surfacing the remote JSON-RPC internal server error to the caller. -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn streamable_http_remote_tools_list_retries_no_response_failure() -> anyhow::Result<()> { - let (_server, base_url) = spawn_streamable_http_server().await?; - let proxy = DropNextMcpPostProxy::spawn(&base_url).await?; - let exec_server = spawn_exec_server().await?; - let client = create_remote_client(proxy.base_url(), exec_server.client.clone()).await?; - - proxy.arm_next_mcp_post_drop(); - let tools = client - .list_tools(/*params*/ None, Some(Duration::from_secs(5))) - .await?; - - assert_eq!(proxy.dropped_mcp_posts(), 1); - assert_eq!(tools.tools.len(), 1); - assert_eq!(tools.tools[0].name, "echo"); - - Ok(()) -} - -struct DropNextMcpPostProxy { - base_url: String, - drops_remaining: Arc, - dropped_mcp_posts: Arc, - task: JoinHandle<()>, -} - -impl DropNextMcpPostProxy { - async fn spawn(target_base_url: &str) -> anyhow::Result { - let target_addr = parse_target_addr(target_base_url)?; - let listener = TcpListener::bind("127.0.0.1:0").await?; - let proxy_addr = listener.local_addr()?; - let drops_remaining = 
Arc::new(AtomicUsize::new(0)); - let dropped_mcp_posts = Arc::new(AtomicUsize::new(0)); - let task_drops_remaining = Arc::clone(&drops_remaining); - let task_dropped_mcp_posts = Arc::clone(&dropped_mcp_posts); - - let task = tokio::spawn(async move { - while let Ok((client, _addr)) = listener.accept().await { - let connection_drops_remaining = Arc::clone(&task_drops_remaining); - let connection_dropped_mcp_posts = Arc::clone(&task_dropped_mcp_posts); - tokio::spawn(async move { - let _ = proxy_connection( - client, - target_addr, - connection_drops_remaining, - connection_dropped_mcp_posts, - ) - .await; - }); - } - }); - - Ok(Self { - base_url: format!("http://{proxy_addr}"), - drops_remaining, - dropped_mcp_posts, - task, - }) - } - - fn base_url(&self) -> &str { - &self.base_url - } - - fn arm_next_mcp_post_drop(&self) { - self.drops_remaining.fetch_add(1, Ordering::SeqCst); - } - - fn dropped_mcp_posts(&self) -> usize { - self.dropped_mcp_posts.load(Ordering::SeqCst) - } -} - -impl Drop for DropNextMcpPostProxy { - fn drop(&mut self) { - self.task.abort(); - } -} - -async fn proxy_connection( - mut client: TcpStream, - target_addr: SocketAddr, - drops_remaining: Arc, - dropped_mcp_posts: Arc, -) -> anyhow::Result<()> { - let request = read_http_message(&mut client).await?; - if request.is_empty() { - return Ok(()); - } - - if is_mcp_post(&request) && consume_drop(&drops_remaining) { - dropped_mcp_posts.fetch_add(1, Ordering::SeqCst); - return Ok(()); - } - - let request = with_connection_close(request)?; - let mut upstream = TcpStream::connect(target_addr).await?; - upstream.write_all(&request).await?; - tokio::io::copy(&mut upstream, &mut client).await?; - client.shutdown().await?; - - Ok(()) -} - -async fn read_http_message(stream: &mut TcpStream) -> anyhow::Result> { - let mut message = Vec::new(); - let mut header_end = None; - let mut chunk = [0_u8; 4096]; - - while header_end.is_none() { - let bytes_read = stream.read(&mut chunk).await?; - if bytes_read 
== 0 { - return Ok(message); - } - message.extend_from_slice(&chunk[..bytes_read]); - header_end = find_header_end(&message); - } - - let header_end = header_end.context("HTTP message headers were not terminated")?; - let content_length = content_length(&message[..header_end])?; - let message_len = header_end + content_length; - - while message.len() < message_len { - let bytes_read = stream.read(&mut chunk).await?; - if bytes_read == 0 { - anyhow::bail!("HTTP message ended before body was complete"); - } - message.extend_from_slice(&chunk[..bytes_read]); - } - - Ok(message) -} - -fn find_header_end(bytes: &[u8]) -> Option { - bytes - .windows(4) - .position(|window| window == b"\r\n\r\n") - .map(|position| position + 4) -} - -fn content_length(headers: &[u8]) -> anyhow::Result { - let headers = std::str::from_utf8(headers).context("HTTP headers were not UTF-8")?; - for line in headers.lines().skip(1) { - let Some((name, value)) = line.split_once(':') else { - continue; - }; - if name.eq_ignore_ascii_case("content-length") { - return value - .trim() - .parse::() - .context("Content-Length header was not a usize"); - } - } - Ok(0) -} - -fn is_mcp_post(request: &[u8]) -> bool { - let Some(request_line) = std::str::from_utf8(request) - .ok() - .and_then(|request| request.lines().next()) - else { - return false; - }; - let mut parts = request_line.split_whitespace(); - parts.next() == Some("POST") && parts.next() == Some("/mcp") -} - -fn consume_drop(drops_remaining: &AtomicUsize) -> bool { - drops_remaining - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| { - remaining.checked_sub(1) - }) - .is_ok() -} - -fn with_connection_close(request: Vec) -> anyhow::Result> { - let header_end = find_header_end(&request).context("HTTP request headers were not complete")?; - let headers = std::str::from_utf8(&request[..header_end]).context("request was not UTF-8")?; - let mut next_request = Vec::with_capacity(request.len() + "Connection: close\r\n".len()); - - for 
line in headers - .strip_suffix("\r\n\r\n") - .unwrap_or(headers) - .split("\r\n") - { - if line - .split_once(':') - .is_some_and(|(name, _value)| name.eq_ignore_ascii_case("connection")) - { - continue; - } - next_request.extend_from_slice(line.as_bytes()); - next_request.extend_from_slice(b"\r\n"); - } - next_request.extend_from_slice(b"Connection: close\r\n\r\n"); - next_request.extend_from_slice(&request[header_end..]); - - Ok(next_request) -} - -fn parse_target_addr(base_url: &str) -> anyhow::Result { - let url = reqwest::Url::parse(base_url)?; - let host = url - .host_str() - .context("target URL did not include a host")?; - let port = url - .port_or_known_default() - .context("target URL did not include a port")?; - format!("{host}:{port}") - .parse() - .context("target URL did not resolve to a socket address") -} diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs index cdb2286a1d0..b3cefa4296a 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -44,14 +44,13 @@ use tokio::process::Child; use tokio::process::Command; use tokio::time::sleep; -const INITIALIZE_FAILURE_CONTROL_PATH: &str = "/test/control/initialize-failure"; const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; fn streamable_http_server_bin() -> Result { codex_utils_cargo_bin::cargo_bin("test_streamable_http_server") } -pub(crate) fn init_params() -> InitializeRequestParams { +fn init_params() -> InitializeRequestParams { let mut capabilities = ClientCapabilities::default(); capabilities.elicitation = Some(ElicitationCapability { form: Some(FormElicitationCapability { @@ -188,24 +187,6 @@ pub(crate) async fn arm_session_post_failure( Ok(()) } -pub(crate) async fn arm_initialize_failure( - base_url: &str, - status: u16, - remaining: usize, -) -> anyhow::Result<()> { - let response = 
reqwest::Client::new() - .post(format!("{base_url}{INITIALIZE_FAILURE_CONTROL_PATH}")) - .json(&json!({ - "status": status, - "remaining": remaining, - })) - .send() - .await?; - - assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); - Ok(()) -} - pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> { let listener = TcpListener::bind("127.0.0.1:0")?; let port = listener.local_addr()?.port();