From cbfce9450396e25dbc48cdc1daaa3fb347601959 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 10:40:06 -0700 Subject: [PATCH 01/23] Works --- codex-rs/Cargo.lock | 202 +++++++++++++++++++- codex-rs/Cargo.toml | 2 + codex-rs/core/Cargo.toml | 1 + codex-rs/core/src/codex.rs | 5 +- codex-rs/core/src/config.rs | 8 + codex-rs/core/src/mcp_connection_manager.rs | 128 ++++++++----- codex-rs/mcp-client/src/main.rs | 5 +- codex-rs/mcp-client/src/mcp_client.rs | 3 +- codex-rs/rmcp-client/Cargo.toml | 33 ++++ codex-rs/rmcp-client/src/lib.rs | 5 + codex-rs/rmcp-client/src/rmcp_client.rs | 160 ++++++++++++++++ codex-rs/rmcp-client/src/utils.rs | 159 +++++++++++++++ codex-rs/tui/src/lib.rs | 5 +- 13 files changed, 653 insertions(+), 63 deletions(-) create mode 100644 codex-rs/rmcp-client/Cargo.toml create mode 100644 codex-rs/rmcp-client/src/lib.rs create mode 100644 codex-rs/rmcp-client/src/rmcp_client.rs create mode 100644 codex-rs/rmcp-client/src/utils.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 5ad73350c9..4a07146b44 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -488,6 +488,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.42" @@ -680,6 +686,7 @@ dependencies = [ "codex-file-search", "codex-mcp-client", "codex-protocol", + "codex-rmcp-client", "core_test_support", "dirs", "env-flags", @@ -921,6 +928,20 @@ dependencies = [ "ts-rs", ] +[[package]] +name = "codex-rmcp-client" +version = "0.0.0" +dependencies = [ + "anyhow", + "mcp-types", + "pretty_assertions", + "rmcp", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "codex-tui" version = "0.0.0" @@ -1237,8 +1258,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", ] [[package]] @@ -1255,13 +1286,38 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.104", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", "quote", "syn 2.0.104", ] @@ -2462,7 +2518,7 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "435d80800b936787d62688c927b6490e887c7ef5ff9ce922c6c6050fca75eb9a" dependencies = [ - "darling", + "darling 0.20.11", "indoc", "proc-macro2", "quote", @@ -2932,7 +2988,19 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ "bitflags 2.9.1", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", + "libc", +] + +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "cfg_aliases 0.2.1", "libc", ] @@ -3358,7 +3426,7 @@ dependencies = [ "lazy_static", "libc", "log", - "nix", + "nix 0.28.0", "serial2", "shared_library", "shell-words", @@ -3446,6 +3514,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "8.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1" +dependencies = [ + "futures", + "indexmap 2.10.0", + "nix 0.30.1", + "tokio", + "tracing", + "windows", +] + [[package]] name = "pulldown-cmark" version = "0.10.3" @@ -3713,6 +3795,42 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af" +dependencies = [ + "base64", + "chrono", + "futures", + "paste", + "pin-project-lite", + "process-wrap", + "rmcp-macros", + "schemars 1.0.4", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ba777eb0e5f53a757e36f0e287441da0ab766564ba7201600eeb92a4753022e" +dependencies = [ + "darling 0.21.3", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.104", +] + [[package]] name = "rustc-demangle" version = "0.1.25" @@ -3798,7 +3916,7 @@ dependencies = [ "libc", "log", "memchr", - "nix", + "nix 0.28.0", "radix_trie", "unicode-segmentation", "unicode-width 0.1.14", @@ -3879,7 +3997,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ "dyn-clone", - "schemars_derive", + "schemars_derive 0.8.22", "serde", "serde_json", ] @@ -3902,8 +4020,10 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" dependencies = [ + "chrono", "dyn-clone", "ref-cast", + "schemars_derive 1.0.4", "serde", "serde_json", ] @@ -3920,6 +4040,18 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "schemars_derive" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d020396d1d138dc19f1165df7545479dcd58d93810dc5d646a16e55abefa80" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.104", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -4071,7 +4203,7 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -5361,6 +5493,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-link 0.1.3", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core", +] + [[package]] name = "windows-core" version = "0.61.2" @@ -5374,6 +5528,17 @@ dependencies = [ "windows-strings", ] +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core", + "windows-link 0.1.3", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -5408,6 +5573,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core", + "windows-link 0.1.3", +] + [[package]] name = "windows-registry" version = "0.5.3" @@ -5535,6 +5710,15 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link 0.1.3", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 7b4db5fcc5..d4c10ff01c 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -18,6 +18,7 @@ members = [ "ollama", "protocol", "protocol-ts", + "rmcp-client", "tui", "utils/readiness", ] @@ -48,6 +49,7 @@ codex-mcp-client = { path = "mcp-client" } codex-mcp-server = { path = "mcp-server" } codex-ollama = { path = "ollama" } codex-protocol = { path = "protocol" } +codex-rmcp-client = { path = "rmcp-client" } codex-protocol-ts = { path = "protocol-ts" } codex-tui = { path = "tui" } codex-utils-readiness = { path = "utils/readiness" } diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index d9ded08283..f1d16dc405 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -21,6 +21,7 @@ chrono = { workspace = true, features = ["serde"] } codex-apply-patch = { workspace = true } codex-file-search = { workspace = true } codex-mcp-client = { workspace = true } +codex-rmcp-client = { workspace = true } codex-protocol = { workspace = true } dirs = { workspace = true } env-flags = { workspace = true } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 48aa3dbd7c..052b5f7ed5 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -377,7 +377,10 @@ impl Session { // - load history metadata let rollout_fut = RolloutRecorder::new(&config, rollout_params); - let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone()); + let mcp_fut = McpConnectionManager::new( + config.mcp_servers.clone(), + config.experimental_use_rmcp_client, + ); let default_shell_fut = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 508a3dc36f..66e333283e 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -184,6 +184,8 @@ pub struct Config { /// If set to `true`, used only the experimental unified exec tool. pub use_experimental_unified_exec_tool: bool, + pub experimental_use_rmcp_client: bool, + /// Include the `view_image` tool that lets the agent attach a local image path to context. pub include_view_image_tool: bool, @@ -693,6 +695,7 @@ pub struct ConfigToml { pub experimental_use_exec_command_tool: Option, pub experimental_use_unified_exec_tool: Option, + pub experimental_use_rmcp_client: Option, pub projects: Option>, @@ -1043,6 +1046,7 @@ impl Config { use_experimental_unified_exec_tool: cfg .experimental_use_unified_exec_tool .unwrap_or(false), + experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false), include_view_image_tool, active_profile: active_profile_name, disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false), @@ -1651,6 +1655,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("o3".to_string()), disable_paste_burst: false, @@ -1709,6 +1714,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("gpt3".to_string()), disable_paste_burst: false, @@ -1782,6 +1788,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("zdr".to_string()), disable_paste_burst: false, @@ -1841,6 +1848,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("gpt5".to_string()), disable_paste_burst: false, diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index e9c95fc80b..c0e984d66d 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -16,6 +16,7 @@ use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use codex_mcp_client::McpClient; +use codex_rmcp_client::RmcpClient; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::Tool; @@ -86,11 +87,42 @@ struct ToolInfo { } struct ManagedClient { - client: Arc, + client: McpClientAdapter, startup_timeout: Duration, tool_timeout: Option, } +#[derive(Clone)] +enum McpClientAdapter { + Legacy(Arc), + Rmcp(Arc), +} + +impl McpClientAdapter { + async fn list_tools( + &self, + params: Option, + timeout: Option, + ) -> Result { + match self { + McpClientAdapter::Legacy(client) => client.list_tools(params, timeout).await, + McpClientAdapter::Rmcp(client) => client.list_tools(params, timeout).await, + } + } + + async fn call_tool( + &self, + name: String, + arguments: Option, + timeout: Option, + ) -> Result { + match self { + McpClientAdapter::Legacy(client) => client.call_tool(name, arguments, timeout).await, + McpClientAdapter::Rmcp(client) => client.call_tool(name, arguments, timeout).await, + } + } +} + /// A thin wrapper around a set of running [`McpClient`] instances. #[derive(Default)] pub(crate) struct McpConnectionManager { @@ -115,6 +147,7 @@ impl McpConnectionManager { /// user should be informed about these errors. pub async fn new( mcp_servers: HashMap, + use_rmcp_client: bool, ) -> Result<(Self, ClientStartErrors)> { // Early exit if no servers are configured. if mcp_servers.is_empty() { @@ -140,54 +173,59 @@ impl McpConnectionManager { let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT); + let use_rmcp_client_flag = use_rmcp_client; join_set.spawn(async move { let McpServerConfig { command, args, env, .. } = cfg; - let client_res = McpClient::new_stdio_client( - command.into(), - args.into_iter().map(OsString::from).collect(), - env, - ) - .await; - match client_res { - Ok(client) => { - // Initialize the client. - let params = mcp_types::InitializeRequestParams { - capabilities: ClientCapabilities { - experimental: None, - roots: None, - sampling: None, - // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities - // indicates this should be an empty object. - elicitation: Some(json!({})), - }, - client_info: Implementation { - name: "codex-mcp-client".to_owned(), - version: env!("CARGO_PKG_VERSION").to_owned(), - title: Some("Codex".into()), - // This field is used by Codex when it is an MCP - // server: it should not be used when Codex is - // an MCP client. - user_agent: None, - }, - protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), - }; - let initialize_notification_params = None; - let init_result = client - .initialize( - params, - initialize_notification_params, - Some(startup_timeout), - ) - .await; - ( - (server_name, tool_timeout), - init_result.map(|_| (client, startup_timeout)), - ) + let command_os = OsString::from(command); + let args_os: Vec = args.into_iter().map(OsString::from).collect(); + let params = mcp_types::InitializeRequestParams { + capabilities: ClientCapabilities { + experimental: None, + roots: None, + sampling: None, + // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities + // indicates this should be an empty object. + elicitation: Some(json!({})), + }, + client_info: Implementation { + name: "codex-mcp-client".to_owned(), + version: env!("CARGO_PKG_VERSION").to_owned(), + title: Some("Codex".into()), + // This field is used by Codex when it is an MCP + // server: it should not be used when Codex is + // an MCP client. + user_agent: None, + }, + protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), + }; + + let client = if use_rmcp_client_flag { + match RmcpClient::new_stdio_client(command_os, args_os, env).await { + Ok(client) => { + let client = Arc::new(client); + match client.initialize(params, Some(startup_timeout)).await { + Ok(_) => Ok((McpClientAdapter::Rmcp(client), startup_timeout)), + Err(e) => Err(e), + } + } + Err(e) => Err(e.into()), } - Err(e) => ((server_name, tool_timeout), Err(e.into())), - } + } else { + match McpClient::new_stdio_client(command_os, args_os, env).await { + Ok(client) => { + let client = Arc::new(client); + match client.initialize(params, Some(startup_timeout)).await { + Ok(_) => Ok((McpClientAdapter::Legacy(client), startup_timeout)), + Err(e) => Err(e), + } + } + Err(e) => Err(e.into()), + } + }; + + ((server_name, tool_timeout), client) }); } @@ -207,7 +245,7 @@ impl McpConnectionManager { clients.insert( server_name, ManagedClient { - client: Arc::new(client), + client, startup_timeout, tool_timeout: Some(tool_timeout), }, diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index d25bca4ba3..f46058b99e 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -70,11 +70,8 @@ async fn main() -> Result<()> { }, protocol_version: MCP_SCHEMA_VERSION.to_owned(), }; - let initialize_notification_params = None; let timeout = Some(Duration::from_secs(10)); - let response = client - .initialize(params, initialize_notification_params, timeout) - .await?; + let response = client.initialize(params, timeout).await?; eprintln!("initialize response: {response:?}"); // Issue `tools/list` request (no params). diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index 505df6bd4e..087335e66b 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -315,13 +315,12 @@ impl McpClient { pub async fn initialize( &self, initialize_params: InitializeRequestParams, - initialize_notification_params: Option, timeout: Option, ) -> Result { let response = self .send_request::(initialize_params, timeout) .await?; - self.send_notification::(initialize_notification_params) + self.send_notification::(None) .await?; Ok(response) } diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml new file mode 100644 index 0000000000..52126330e8 --- /dev/null +++ b/codex-rs/rmcp-client/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "codex-rmcp-client" +version = { workspace = true } +edition = "2024" + +[lints] +workspace = true + +[dependencies] +anyhow = "1" +mcp-types = { path = "../mcp-types" } +rmcp = { version = "0.7.0", default-features = false, features = [ + "base64", + "client", + "macros", + "schemars", + "server", + "transport-child-process", +] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1", features = [ + "io-util", + "macros", + "process", + "rt-multi-thread", + "sync", + "time", +] } +tracing = { version = "0.1.41", features = ["log"] } + +[dev-dependencies] +pretty_assertions = "1.4.1" diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs new file mode 100644 index 0000000000..ef5088406c --- /dev/null +++ b/codex-rs/rmcp-client/src/lib.rs @@ -0,0 +1,5 @@ +mod logging_client_handler; +mod rmcp_client; +mod utils; + +pub use rmcp_client::RmcpClient; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs new file mode 100644 index 0000000000..9a94298195 --- /dev/null +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -0,0 +1,160 @@ +use std::collections::HashMap; +use std::ffi::OsString; +use std::io; +use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use anyhow::anyhow; +use mcp_types::CallToolRequestParams; +use mcp_types::CallToolResult; +use mcp_types::InitializeRequestParams; +use mcp_types::InitializeResult; +use mcp_types::ListToolsRequestParams; +use mcp_types::ListToolsResult; +use rmcp::model::CallToolRequestParam; +use rmcp::model::InitializeRequestParam; +use rmcp::model::PaginatedRequestParam; +use rmcp::service::RoleClient; +use rmcp::service::RunningService; +use rmcp::service::{self}; +use rmcp::transport::child_process::TokioChildProcess; +use tokio::process::Command; +use tokio::sync::Mutex; +use tokio::time; + +use crate::logging_client_handler::LoggingClientHandler; +use crate::utils::convert_call_tool_result; +use crate::utils::convert_to_mcp; +use crate::utils::convert_to_rmcp; +use crate::utils::create_env_for_mcp_server; +use crate::utils::run_with_timeout; + +enum ClientState { + Connecting { + transport: Option, + }, + Ready { + service: Arc>, + }, +} + +/// MCP client implemented on top of the official `rmcp` SDK. +/// https://github.com/modelcontextprotocol/rust-sdk +pub struct RmcpClient { + state: Mutex, +} + +impl RmcpClient { + pub async fn new_stdio_client( + program: OsString, + args: Vec, + env: Option>, + ) -> io::Result { + let mut command = Command::new(program); + command + .kill_on_drop(true) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::null()) + .env_clear() + .envs(create_env_for_mcp_server(env)) + .args(args); + + let (transport, _stderr) = TokioChildProcess::builder(command).spawn()?; + + Ok(Self { + state: Mutex::new(ClientState::Connecting { + transport: Some(transport), + }), + }) + } + + /// Perform the initialization handshake with the MCP server. + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization + pub async fn initialize( + &self, + params: InitializeRequestParams, + timeout: Option, + ) -> Result { + let transport = { + let mut guard = self.state.lock().await; + match &mut *guard { + ClientState::Connecting { transport } => transport + .take() + .ok_or_else(|| anyhow!("client already initializing"))?, + ClientState::Ready { .. } => { + return Err(anyhow!("client already initialized")); + } + } + }; + + let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; + let client_handler = LoggingClientHandler::new(client_info); + let service_future = service::serve_client(client_handler, transport); + + let service = match timeout { + Some(duration) => time::timeout(duration, service_future) + .await + .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? + .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + None => service_future + .await + .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + }; + + let initialize_result_rmcp = service + .peer() + .peer_info() + .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; + let initialize_result = convert_to_mcp(initialize_result_rmcp)?; + + { + let mut guard = self.state.lock().await; + *guard = ClientState::Ready { + service: Arc::new(service), + }; + } + + Ok(initialize_result) + } + + pub async fn list_tools( + &self, + params: Option, + timeout: Option, + ) -> Result { + let service = self.service().await?; + let rmcp_params = match params { + Some(p) => Some(convert_to_rmcp::<_, PaginatedRequestParam>(p)?), + None => None, + }; + + let fut = service.list_tools(rmcp_params); + let result = run_with_timeout(fut, timeout, "tools/list").await?; + convert_to_mcp(result) + } + + pub async fn call_tool( + &self, + name: String, + arguments: Option, + timeout: Option, + ) -> Result { + let service = self.service().await?; + let params = CallToolRequestParams { arguments, name }; + let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; + let fut = service.call_tool(rmcp_params); + let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?; + convert_call_tool_result(rmcp_result) + } + + async fn service(&self) -> Result>> { + let guard = self.state.lock().await; + match &*guard { + ClientState::Ready { service } => Ok(Arc::clone(service)), + ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")), + } + } +} diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs new file mode 100644 index 0000000000..22b9daf477 --- /dev/null +++ b/codex-rs/rmcp-client/src/utils.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use anyhow::Context; +use anyhow::Result; +use anyhow::anyhow; +use mcp_types::CallToolResult; +use rmcp::model::CallToolResult as RmcpCallToolResult; +use rmcp::service::ServiceError; +use serde_json::Value; +use tokio::time; + +pub(crate) async fn run_with_timeout( + fut: F, + timeout: Option, + label: &str, +) -> Result +where + F: std::future::Future>, +{ + if let Some(duration) = timeout { + let result = time::timeout(duration, fut) + .await + .map_err(|_| anyhow!("timed out awaiting {label} after {duration:?}"))?; + result.map_err(|err| anyhow!("{label} failed: {err}")) + } else { + fut.await.map_err(|err| anyhow!("{label} failed: {err}")) + } +} + +pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result { + let mut value = serde_json::to_value(result)?; + if let Some(obj) = value.as_object_mut() + && (obj.get("content").is_none() || obj.get("content").is_some_and(serde_json::Value::is_null)) + { + obj.insert("content".to_string(), Value::Array(Vec::new())); + } + serde_json::from_value(value).context("failed to convert call tool result") +} + +/// Convert from mcp-types to Rust SDK types. +/// +/// The Rust SDK types are the same as our mcp-types crate because they are both +/// derived from the same MCP specification. +/// As a result, it should be safe to convert directly from one to the other. +pub(crate) fn convert_to_rmcp(value: T) -> Result +where + T: serde::Serialize, + U: serde::de::DeserializeOwned, +{ + let json = serde_json::to_value(value)?; + serde_json::from_value(json).map_err(|err| anyhow!(err)) +} + +/// Convert from Rust SDK types to mcp-types. +/// +/// The Rust SDK types are the same as our mcp-types crate because they are both +/// derived from the same MCP specification. +/// As a result, it should be safe to convert directly from one to the other. +pub(crate) fn convert_to_mcp(value: T) -> Result +where + T: serde::Serialize, + U: serde::de::DeserializeOwned, +{ + let json = serde_json::to_value(value)?; + serde_json::from_value(json).map_err(|err| anyhow!(err)) +} + +pub(crate) fn create_env_for_mcp_server( + extra_env: Option>, +) -> HashMap { + DEFAULT_ENV_VARS + .iter() + .filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value))) + .chain(extra_env.unwrap_or_default()) + .collect() +} + +#[cfg(unix)] +pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ + "HOME", + "LOGNAME", + "PATH", + "SHELL", + "USER", + "__CF_USER_TEXT_ENCODING", + "LANG", + "LC_ALL", + "TERM", + "TMPDIR", + "TZ", +]; + +#[cfg(windows)] +pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ + "PATH", + "PATHEXT", + "USERNAME", + "USERDOMAIN", + "USERPROFILE", + "TEMP", + "TMP", +]; + +#[cfg(test)] +mod tests { + use super::*; + use mcp_types::ContentBlock; + use pretty_assertions::assert_eq; + use rmcp::model::CallToolResult as RmcpCallToolResult; + use serde_json::json; + + #[tokio::test] + async fn create_env_honors_overrides() { + let value = "custom".to_string(); + let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())]))); + assert_eq!(env.get("TZ"), Some(&value)); + } + + #[test] + fn convert_call_tool_result_defaults_missing_content() -> Result<()> { + let structured_content = json!({ "key": "value" }); + let rmcp_result = RmcpCallToolResult { + content: vec![], + structured_content: Some(structured_content.clone()), + is_error: Some(true), + meta: None, + }; + + let result = convert_call_tool_result(rmcp_result)?; + + assert!(result.content.is_empty()); + assert_eq!(result.structured_content, Some(structured_content)); + assert_eq!(result.is_error, Some(true)); + + Ok(()) + } + + #[test] + fn convert_call_tool_result_preserves_existing_content() -> Result<()> { + let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]); + + let result = convert_call_tool_result(rmcp_result)?; + + assert_eq!(result.content.len(), 1); + match &result.content[0] { + ContentBlock::TextContent(text_content) => { + assert_eq!(text_content.text, "hello"); + assert_eq!(text_content.r#type, "text"); + } + other => panic!("expected text content got {other:?}"), + } + assert_eq!(result.structured_content, None); + assert_eq!(result.is_error, Some(false)); + + Ok(()) + } +} diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 1453971ca4..f563a7c0ed 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -223,8 +223,9 @@ pub async fn run_main( // use RUST_LOG env var, default to info for codex crates. let env_filter = || { - EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("codex_core=info,codex_tui=info")) + EnvFilter::try_from_default_env().unwrap_or_else(|_| { + EnvFilter::new("codex_core=info,codex_tui=info,codex_rmcp_client=info") + }) }; // Build layered subscriber: From 9bc2d9c5a6d29f3d128e8a9b78908d293d03fb99 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 14:36:45 -0700 Subject: [PATCH 02/23] logging client handler + stderr test --- codex-rs/core/src/mcp_connection_manager.rs | 57 ++++---- .../rmcp-client/src/logging_client_handler.rs | 134 ++++++++++++++++++ codex-rs/rmcp-client/src/rmcp_client.rs | 34 ++++- codex-rs/rmcp-client/src/utils.rs | 3 +- 4 files changed, 196 insertions(+), 32 deletions(-) create mode 100644 codex-rs/rmcp-client/src/logging_client_handler.rs diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index c0e984d66d..5ec4f05757 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -99,6 +99,25 @@ enum McpClientAdapter { } impl McpClientAdapter { + async fn new_stdio_client( + use_rmcp_client: bool, + program: OsString, + args: Vec, + env: Option>, + params: mcp_types::InitializeRequestParams, + startup_timeout: Duration, + ) -> Result { + if use_rmcp_client { + let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?); + client.initialize(params, Some(startup_timeout)).await?; + Ok(McpClientAdapter::Rmcp(client)) + } else { + let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?); + client.initialize(params, Some(startup_timeout)).await?; + Ok(McpClientAdapter::Legacy(client)) + } + } + async fn list_tools( &self, params: Option, @@ -170,7 +189,6 @@ impl McpConnectionManager { } let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT); - let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT); let use_rmcp_client_flag = use_rmcp_client; @@ -178,8 +196,8 @@ impl McpConnectionManager { let McpServerConfig { command, args, env, .. } = cfg; - let command_os = OsString::from(command); - let args_os: Vec = args.into_iter().map(OsString::from).collect(); + let command_os: OsString = command.into(); + let args_os: Vec = args.into_iter().map(Into::into).collect(); let params = mcp_types::InitializeRequestParams { capabilities: ClientCapabilities { experimental: None, @@ -201,29 +219,16 @@ impl McpConnectionManager { protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; - let client = if use_rmcp_client_flag { - match RmcpClient::new_stdio_client(command_os, args_os, env).await { - Ok(client) => { - let client = Arc::new(client); - match client.initialize(params, Some(startup_timeout)).await { - Ok(_) => Ok((McpClientAdapter::Rmcp(client), startup_timeout)), - Err(e) => Err(e), - } - } - Err(e) => Err(e.into()), - } - } else { - match McpClient::new_stdio_client(command_os, args_os, env).await { - Ok(client) => { - let client = Arc::new(client); - match client.initialize(params, Some(startup_timeout)).await { - Ok(_) => Ok((McpClientAdapter::Legacy(client), startup_timeout)), - Err(e) => Err(e), - } - } - Err(e) => Err(e.into()), - } - }; + let client = McpClientAdapter::new_stdio_client( + use_rmcp_client_flag, + command_os, + args_os, + env, + params, + startup_timeout, + ) + .await + .map(|c| (c, startup_timeout)); ((server_name, tool_timeout), client) }); diff --git a/codex-rs/rmcp-client/src/logging_client_handler.rs b/codex-rs/rmcp-client/src/logging_client_handler.rs new file mode 100644 index 0000000000..85d237b0e9 --- /dev/null +++ b/codex-rs/rmcp-client/src/logging_client_handler.rs @@ -0,0 +1,134 @@ +use rmcp::ClientHandler; +use rmcp::RoleClient; +use rmcp::model::CancelledNotificationParam; +use rmcp::model::ClientInfo; +use rmcp::model::CreateElicitationRequestParam; +use rmcp::model::CreateElicitationResult; +use rmcp::model::ElicitationAction; +use rmcp::model::LoggingLevel; +use rmcp::model::LoggingMessageNotificationParam; +use rmcp::model::ProgressNotificationParam; +use rmcp::model::ResourceUpdatedNotificationParam; +use rmcp::service::NotificationContext; +use rmcp::service::RequestContext; +use tracing::debug; +use tracing::error; +use tracing::info; +use tracing::warn; + +#[derive(Debug, Clone)] +pub(crate) struct LoggingClientHandler { + client_info: ClientInfo, +} + +impl LoggingClientHandler { + pub(crate) fn new(client_info: ClientInfo) -> Self { + Self { client_info } + } +} + +impl ClientHandler for LoggingClientHandler { + // TODO (CODEX-3571): support elicitations. + async fn create_elicitation( + &self, + request: CreateElicitationRequestParam, + _context: RequestContext, + ) -> Result { + info!( + "MCP server requested elicitation ({}). Elicitations are not supported yet. Declining.", + request.message + ); + Ok(CreateElicitationResult { + action: ElicitationAction::Decline, + content: None, + }) + } + + async fn on_cancelled( + &self, + params: CancelledNotificationParam, + _context: NotificationContext, + ) { + info!( + "MCP server cancelled request (request_id: {}, reason: {:?})", + params.request_id, params.reason + ); + } + + async fn on_progress( + &self, + params: ProgressNotificationParam, + _context: NotificationContext, + ) { + info!( + "MCP server progress notification (token: {:?}, progress: {}, total: {:?}, message: {:?})", + params.progress_token, params.progress, params.total, params.message + ); + } + + async fn on_resource_updated( + &self, + params: ResourceUpdatedNotificationParam, + _context: NotificationContext, + ) { + info!("MCP server resource updated (uri: {})", params.uri); + } + + async fn on_resource_list_changed(&self, _context: NotificationContext) { + info!("MCP server resource list changed"); + } + + async fn on_tool_list_changed(&self, _context: NotificationContext) { + info!("MCP server tool list changed"); + } + + async fn on_prompt_list_changed(&self, _context: NotificationContext) { + info!("MCP server prompt list changed"); + } + + fn get_info(&self) -> ClientInfo { + self.client_info.clone() + } + + async fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + _context: NotificationContext, + ) { + let LoggingMessageNotificationParam { + level, + logger, + data, + } = params; + let logger = logger.as_deref(); + match level { + LoggingLevel::Emergency + | LoggingLevel::Alert + | LoggingLevel::Critical + | LoggingLevel::Error => { + error!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + LoggingLevel::Warning => { + warn!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + LoggingLevel::Notice | LoggingLevel::Info => { + info!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + LoggingLevel::Debug => { + debug!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + } + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 9a94298195..472aa21e6b 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -20,9 +20,13 @@ use rmcp::service::RoleClient; use rmcp::service::RunningService; use rmcp::service::{self}; use rmcp::transport::child_process::TokioChildProcess; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; use tokio::process::Command; use tokio::sync::Mutex; use tokio::time; +use tracing::info; +use tracing::warn; use crate::logging_client_handler::LoggingClientHandler; use crate::utils::convert_call_tool_result; @@ -52,17 +56,37 @@ impl RmcpClient { args: Vec, env: Option>, ) -> io::Result { - let mut command = Command::new(program); + let program_name = program.to_string_lossy().into_owned(); + let mut command = Command::new(&program); command .kill_on_drop(true) .stdin(Stdio::piped()) .stdout(Stdio::piped()) - .stderr(Stdio::null()) .env_clear() .envs(create_env_for_mcp_server(env)) - .args(args); - - let (transport, _stderr) = TokioChildProcess::builder(command).spawn()?; + .args(&args); + + let (transport, stderr) = TokioChildProcess::builder(command) + .stderr(Stdio::piped()) + .spawn()?; + + if let Some(stderr) = stderr { + tokio::spawn(async move { + let mut reader = BufReader::new(stderr).lines(); + loop { + match reader.next_line().await { + Ok(Some(line)) => { + info!("MCP server stderr ({program_name}): {line}"); + } + Ok(None) => break, + Err(error) => { + warn!("Failed to read MCP server stderr ({program_name}): {error}"); + break; + } + } + } + }); + } Ok(Self { state: Mutex::new(ClientState::Connecting { diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index 22b9daf477..500fe9bc6c 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -32,7 +32,8 @@ where pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result { let mut value = serde_json::to_value(result)?; if let Some(obj) = value.as_object_mut() - && (obj.get("content").is_none() || obj.get("content").is_some_and(serde_json::Value::is_null)) + && (obj.get("content").is_none() + || obj.get("content").is_some_and(serde_json::Value::is_null)) { obj.insert("content".to_string(), Value::Array(Vec::new())); } From 2bee4fa8cf869d56c784edb8033dc4cec5e85624 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 14:57:15 -0700 Subject: [PATCH 03/23] use_* --- codex-rs/core/src/codex.rs | 2 +- codex-rs/core/src/config.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 052b5f7ed5..50c72ec806 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -379,7 +379,7 @@ impl Session { let mcp_fut = McpConnectionManager::new( config.mcp_servers.clone(), - config.experimental_use_rmcp_client, + config.use_experimental_use_rmcp_client, ); let default_shell_fut = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 66e333283e..dc5f43036f 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -184,7 +184,7 @@ pub struct Config { /// If set to `true`, used only the experimental unified exec tool. pub use_experimental_unified_exec_tool: bool, - pub experimental_use_rmcp_client: bool, + pub use_experimental_use_rmcp_client: bool, /// Include the `view_image` tool that lets the agent attach a local image path to context. pub include_view_image_tool: bool, @@ -1046,7 +1046,7 @@ impl Config { use_experimental_unified_exec_tool: cfg .experimental_use_unified_exec_tool .unwrap_or(false), - experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false), + use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false), include_view_image_tool, active_profile: active_profile_name, disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false), @@ -1655,7 +1655,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, - experimental_use_rmcp_client: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("o3".to_string()), disable_paste_burst: false, @@ -1714,7 +1714,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, - experimental_use_rmcp_client: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("gpt3".to_string()), disable_paste_burst: false, @@ -1788,7 +1788,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, - experimental_use_rmcp_client: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("zdr".to_string()), disable_paste_burst: false, @@ -1848,7 +1848,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, - experimental_use_rmcp_client: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("gpt5".to_string()), disable_paste_burst: false, From adb26689f133c1ff6e3e64d3381cb64b96ea556e Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 14:57:56 -0700 Subject: [PATCH 04/23] Doc --- codex-rs/core/src/config.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index dc5f43036f..5b5b60f8df 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -184,6 +184,8 @@ pub struct Config { /// If set to `true`, used only the experimental unified exec tool. pub use_experimental_unified_exec_tool: bool, + /// If set to `true`, use the experimental official Rust MCP client. + /// https://github.com/modelcontextprotocol/rust-sdk pub use_experimental_use_rmcp_client: bool, /// Include the `view_image` tool that lets the agent attach a local image path to context. From 879bd2a7709acfe39e7d9717055a0ecd4e745ba9 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 15:39:40 -0700 Subject: [PATCH 05/23] --wip-- [skip ci] --- codex-rs/Cargo.lock | 130 ++++++++++++++++++++ codex-rs/cli/src/mcp_cmd.rs | 105 ++++++++++------ codex-rs/cli/tests/mcp_add_remove.rs | 4 +- codex-rs/core/src/config.rs | 59 ++++++--- codex-rs/core/src/config_types.rs | 111 ++++++++++++++++- codex-rs/core/src/mcp_connection_manager.rs | 88 +++++++++++-- codex-rs/rmcp-client/Cargo.toml | 7 ++ codex-rs/rmcp-client/src/rmcp_client.rs | 36 +++++- codex-rs/tui/src/history_cell.rs | 18 ++- 9 files changed, 475 insertions(+), 83 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 4a07146b44..511adf147e 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -933,8 +933,10 @@ name = "codex-rmcp-client" version = "0.0.0" dependencies = [ "anyhow", + "futures", "mcp-types", "pretty_assertions", + "reqwest", "rmcp", "serde", "serde_json", @@ -1999,8 +2001,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -2010,9 +2014,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -2209,6 +2215,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", + "webpki-roots", ] [[package]] @@ -2809,6 +2816,12 @@ dependencies = [ "hashbrown 0.15.4", ] +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "lsp-types" version = "0.94.1" @@ -3571,6 +3584,61 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases 0.2.1", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.16", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.3", + "lru-slab", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.16", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases 0.2.1", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.40" @@ -3763,6 +3831,8 @@ dependencies = [ "native-tls", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pki-types", "serde", "serde_json", @@ -3770,6 +3840,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls", "tokio-util", "tower", "tower-http", @@ -3779,6 +3850,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", + "webpki-roots", ] [[package]] @@ -3804,13 +3876,16 @@ dependencies = [ "base64", "chrono", "futures", + "http", "paste", "pin-project-lite", "process-wrap", + "reqwest", "rmcp-macros", "schemars 1.0.4", "serde", "serde_json", + "sse-stream", "thiserror 2.0.16", "tokio", "tokio-stream", @@ -3837,6 +3912,12 @@ version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "0.38.44" @@ -3870,6 +3951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", @@ -3882,6 +3964,7 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ + "web-time", "zeroize", ] @@ -4355,6 +4438,19 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -4810,6 +4906,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.47.1" @@ -5422,6 +5533,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webbrowser" version = "1.0.5" @@ -5438,6 +5559,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-roots" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "weezl" version = "0.1.10" diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 465de71aac..c931935e0e 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -145,9 +145,11 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<( .with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?; let new_entry = McpServerConfig { - command: command_bin, + command: Some(command_bin), args: command_args, env: env_map, + url: None, + bearer_token: None, startup_timeout_sec: None, tool_timeout_sec: None, }; @@ -211,6 +213,8 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul "command": cfg.command, "args": cfg.args, "env": env, + "url": cfg.url, + "bearer_token": cfg.bearer_token, "startup_timeout_sec": cfg .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -232,27 +236,36 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul let mut rows: Vec<[String; 4]> = Vec::new(); for (name, cfg) in entries { - let args = if cfg.args.is_empty() { - "-".to_string() - } else { - cfg.args.join(" ") - }; - - let env = match cfg.env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") + let (command_display, args_display, env_display) = match (&cfg.command, &cfg.url) { + (Some(command), None) => { + let args = if cfg.args.is_empty() { + "-".to_string() + } else { + cfg.args.join(" ") + }; + let env_str = match cfg.env.as_ref() { + None => "-".to_string(), + Some(map) if map.is_empty() => "-".to_string(), + Some(map) => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + }; + (command.clone(), args, env_str) + } + (None, Some(url)) => { + let bearer = cfg.bearer_token.clone().unwrap_or_else(|| "-".to_string()); + (url.clone(), "-".to_string(), bearer) } + _ => ("-".to_string(), "-".to_string(), "-".to_string()), }; - rows.push([name.clone(), cfg.command.clone(), args, env]); + rows.push([name.clone(), command_display, args_display, env_display]); } let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()]; @@ -311,6 +324,8 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( "command": server.command, "args": server.args, "env": env, + "url": server.url, + "bearer_token": server.bearer_token, "startup_timeout_sec": server .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -323,27 +338,39 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( } println!("{}", get_args.name); - println!(" command: {}", server.command); - let args = if server.args.is_empty() { - "-".to_string() - } else { - server.args.join(" ") - }; - println!(" args: {args}"); - let env_display = match server.env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") + match (&server.command, &server.url) { + (Some(command), None) => { + println!(" command: {command}"); + let args = if server.args.is_empty() { + "-".to_string() + } else { + server.args.join(" ") + }; + println!(" args: {args}"); + let env_display = match server.env.as_ref() { + None => "-".to_string(), + Some(map) if map.is_empty() => "-".to_string(), + Some(map) => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + }; + println!(" env: {env_display}"); } - }; - println!(" env: {env_display}"); + (None, Some(url)) => { + println!(" url: {url}"); + let bearer = server.bearer_token.as_deref().unwrap_or("-"); + println!(" bearer_token: {bearer}"); + } + _ => { + println!(" command/url: -"); + } + } if let Some(timeout) = server.startup_timeout_sec { println!(" startup_timeout_sec: {}", timeout.as_secs_f64()); } diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 9e54f0d867..6c328c0472 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -26,9 +26,11 @@ fn add_and_remove_server_updates_global_config() -> Result<()> { let servers = load_global_mcp_servers(codex_home.path())?; assert_eq!(servers.len(), 1); let docs = servers.get("docs").expect("server should exist"); - assert_eq!(docs.command, "echo"); + assert_eq!(docs.command.as_deref(), Some("echo")); assert_eq!(docs.args, vec!["hello".to_string()]); assert!(docs.env.is_none()); + assert!(docs.url.is_none()); + assert!(docs.bearer_token.is_none()); let mut remove_cmd = codex_command(codex_home.path())?; remove_cmd diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 5b5b60f8df..e5df088a06 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -314,27 +314,44 @@ pub fn write_global_mcp_servers( for (name, config) in servers { let mut entry = TomlTable::new(); entry.set_implicit(false); - entry["command"] = toml_edit::value(config.command.clone()); + match (&config.command, &config.url) { + (Some(command), None) => { + entry["command"] = toml_edit::value(command.clone()); + + if !config.args.is_empty() { + let mut args = TomlArray::new(); + for arg in &config.args { + args.push(arg.clone()); + } + entry["args"] = TomlItem::Value(args.into()); + } - if !config.args.is_empty() { - let mut args = TomlArray::new(); - for arg in &config.args { - args.push(arg.clone()); + if let Some(env) = &config.env + && !env.is_empty() + { + let mut env_table = TomlTable::new(); + env_table.set_implicit(false); + let mut pairs: Vec<_> = env.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (key, value) in pairs { + env_table.insert(key, toml_edit::value(value.clone())); + } + entry["env"] = TomlItem::Table(env_table); + } } - entry["args"] = TomlItem::Value(args.into()); - } - - if let Some(env) = &config.env - && !env.is_empty() - { - let mut env_table = TomlTable::new(); - env_table.set_implicit(false); - let mut pairs: Vec<_> = env.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - for (key, value) in pairs { - env_table.insert(key, toml_edit::value(value.clone())); + (None, Some(url)) => { + entry["url"] = toml_edit::value(url.clone()); + if let Some(token) = &config.bearer_token { + entry["bearer_token"] = toml_edit::value(token.clone()); + } + } + _ => { + tracing::warn!( + "skipping MCP server `{}` with invalid transport configuration", + name + ); + continue; } - entry["env"] = TomlItem::Table(env_table); } if let Some(timeout) = config.startup_timeout_sec { @@ -1294,9 +1311,11 @@ exclude_slash_tmp = true servers.insert( "docs".to_string(), McpServerConfig { - command: "echo".to_string(), + command: Some("echo".to_string()), args: vec!["hello".to_string()], env: None, + url: None, + bearer_token: None, startup_timeout_sec: Some(Duration::from_secs(3)), tool_timeout_sec: Some(Duration::from_secs(5)), }, @@ -1307,7 +1326,7 @@ exclude_slash_tmp = true let loaded = load_global_mcp_servers(codex_home.path())?; assert_eq!(loaded.len(), 1); let docs = loaded.get("docs").expect("docs entry"); - assert_eq!(docs.command, "echo"); + assert_eq!(docs.command, Some("echo".to_string())); assert_eq!(docs.args, vec!["hello".to_string()]); assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3))); assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5))); diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index d273b23d69..51c28505e3 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -15,7 +15,8 @@ use serde::de::Error as SerdeError; #[derive(Serialize, Debug, Clone, PartialEq)] pub struct McpServerConfig { - pub command: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub command: Option, #[serde(default)] pub args: Vec, @@ -23,6 +24,12 @@ pub struct McpServerConfig { #[serde(default)] pub env: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub url: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub bearer_token: Option, + /// Startup timeout in seconds for initializing MCP server & initially listing tools. #[serde( default, @@ -43,12 +50,17 @@ impl<'de> Deserialize<'de> for McpServerConfig { { #[derive(Deserialize)] struct RawMcpServerConfig { - command: String, + #[serde(default)] + command: Option, #[serde(default)] args: Vec, #[serde(default)] env: Option>, #[serde(default)] + url: Option, + #[serde(default)] + bearer_token: Option, + #[serde(default)] startup_timeout_sec: Option, #[serde(default)] startup_timeout_ms: Option, @@ -67,16 +79,109 @@ impl<'de> Deserialize<'de> for McpServerConfig { (None, None) => None, }; + let command = raw.command.and_then(normalize_string_option); + let url = raw.url.and_then(normalize_string_option); + + let has_command = command.is_some(); + let has_url = url.is_some(); + + if has_command && has_url { + return Err(SerdeError::custom( + "MCP server config must not set both `command` and `url`", + )); + } + + if !has_command && !has_url { + return Err(SerdeError::custom( + "MCP server config must set either `command` or `url`", + )); + } + + if has_url { + if !raw.args.is_empty() { + return Err(SerdeError::custom( + "`args` is not supported when configuring MCP servers via `url`", + )); + } + if raw.env.as_ref().is_some_and(|env| !env.is_empty()) { + return Err(SerdeError::custom( + "`env` is not supported when configuring MCP servers via `url`", + )); + } + } + Ok(Self { - command: raw.command, + command, args: raw.args, env: raw.env, + url, + bearer_token: raw.bearer_token.and_then(normalize_string_option), startup_timeout_sec, tool_timeout_sec: raw.tool_timeout_sec, }) } } +fn normalize_string_option(value: String) -> Option { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn deserialize_command_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + "#, + ) + .expect("should deserialize command config"); + + assert_eq!(cfg.command.as_deref(), Some("echo")); + assert_eq!(cfg.args, vec!["hello", "world"]); + assert!(cfg.url.is_none()); + } + + #[test] + fn deserialize_streamable_http_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + bearer_token = "secret" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!(cfg.url.as_deref(), Some("https://example.com/mcp")); + assert_eq!(cfg.bearer_token.as_deref(), Some("secret")); + assert!(cfg.command.is_none()); + assert!(cfg.args.is_empty()); + assert!(cfg.env.is_none()); + } + + #[test] + fn deserialize_rejects_invalid_transport_combo() { + let err = toml::from_str::( + r#" + command = "echo" + url = "https://example.com" + "#, + ) + .expect_err("should reject command+url"); + + assert!(err.to_string().contains("must not set both")); + } +} + mod option_duration_secs { use serde::Deserialize; use serde::Deserializer; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 5ec4f05757..21f33d35b1 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -118,6 +118,17 @@ impl McpClientAdapter { } } + async fn new_streamable_http_client( + url: String, + bearer_token: Option, + params: mcp_types::InitializeRequestParams, + startup_timeout: Duration, + ) -> Result { + let client = Arc::new(RmcpClient::new_streamable_http_client(url, bearer_token)?); + client.initialize(params, Some(startup_timeout)).await?; + Ok(McpClientAdapter::Rmcp(client)) + } + async fn list_tools( &self, params: Option, @@ -188,16 +199,52 @@ impl McpConnectionManager { continue; } + let has_command = cfg.command.is_some(); + let has_url = cfg.url.is_some(); + + if has_command && has_url { + errors.insert( + server_name.clone(), + anyhow!( + "MCP server `{}` must not set both `command` and `url`", + server_name + ), + ); + continue; + } + + if !has_command && !has_url { + errors.insert( + server_name.clone(), + anyhow!( + "MCP server `{}` must set either `command` or `url`", + server_name + ), + ); + continue; + } + + if cfg.url.is_some() && !use_rmcp_client { + info!( + "skipping MCP server `{}` configured with url because rmcp client is disabled", + server_name + ); + continue; + } + let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT); let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT); let use_rmcp_client_flag = use_rmcp_client; join_set.spawn(async move { let McpServerConfig { - command, args, env, .. + command, + args, + env, + url, + bearer_token, + .. } = cfg; - let command_os: OsString = command.into(); - let args_os: Vec = args.into_iter().map(Into::into).collect(); let params = mcp_types::InitializeRequestParams { capabilities: ClientCapabilities { experimental: None, @@ -219,16 +266,31 @@ impl McpConnectionManager { protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; - let client = McpClientAdapter::new_stdio_client( - use_rmcp_client_flag, - command_os, - args_os, - env, - params, - startup_timeout, - ) - .await - .map(|c| (c, startup_timeout)); + let client = match (command, url) { + (Some(command), None) => { + let command_os: OsString = command.into(); + let args_os: Vec = args.into_iter().map(Into::into).collect(); + McpClientAdapter::new_stdio_client( + use_rmcp_client_flag, + command_os, + args_os, + env, + params.clone(), + startup_timeout, + ) + .await + .map(|c| (c, startup_timeout)) + } + (None, Some(url)) => McpClientAdapter::new_streamable_http_client( + url, + bearer_token, + params, + startup_timeout, + ) + .await + .map(|c| (c, startup_timeout)), + _ => Err(anyhow!("invalid MCP server transport configuration")), + }; ((server_name, tool_timeout), client) }); diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 52126330e8..e139555eb8 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -16,6 +16,13 @@ rmcp = { version = "0.7.0", default-features = false, features = [ "schemars", "server", "transport-child-process", + "transport-streamable-http-client-reqwest", +] } +futures = { version = "0.3", default-features = false, features = ["std"] } +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", + "rustls-tls", ] } serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 472aa21e6b..125e4079e7 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -7,6 +7,7 @@ use std::time::Duration; use anyhow::Result; use anyhow::anyhow; +use futures::FutureExt; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::InitializeRequestParams; @@ -19,7 +20,9 @@ use rmcp::model::PaginatedRequestParam; use rmcp::service::RoleClient; use rmcp::service::RunningService; use rmcp::service::{self}; +use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::child_process::TokioChildProcess; +use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; @@ -35,9 +38,14 @@ use crate::utils::convert_to_rmcp; use crate::utils::create_env_for_mcp_server; use crate::utils::run_with_timeout; +enum PendingTransport { + ChildProcess(TokioChildProcess), + StreamableHttp(StreamableHttpClientTransport), +} + enum ClientState { Connecting { - transport: Option, + transport: Option, }, Ready { service: Arc>, @@ -90,7 +98,22 @@ impl RmcpClient { Ok(Self { state: Mutex::new(ClientState::Connecting { - transport: Some(transport), + transport: Some(PendingTransport::ChildProcess(transport)), + }), + }) + } + + pub fn new_streamable_http_client(url: String, bearer_token: Option) -> Result { + let mut config = StreamableHttpClientTransportConfig::with_uri(url); + if let Some(token) = bearer_token { + config = config.auth_header(token); + } + + let transport = StreamableHttpClientTransport::from_config(config); + + Ok(Self { + state: Mutex::new(ClientState::Connecting { + transport: Some(PendingTransport::StreamableHttp(transport)), }), }) } @@ -116,7 +139,14 @@ impl RmcpClient { let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; let client_handler = LoggingClientHandler::new(client_info); - let service_future = service::serve_client(client_handler, transport); + let service_future = match transport { + PendingTransport::ChildProcess(transport) => { + service::serve_client(client_handler.clone(), transport).boxed() + } + PendingTransport::StreamableHttp(transport) => { + service::serve_client(client_handler, transport).boxed() + } + }; let service = match timeout { Some(duration) => time::timeout(duration, service_future) diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index 3ae0d177b4..58e9547234 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -1232,10 +1232,20 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Server: ".into(), server.clone().into()].into()); - if !cfg.command.is_empty() { - let cmd_display = format!("{} {}", cfg.command, cfg.args.join(" ")); - - lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); + match (&cfg.command, &cfg.url) { + (Some(command), None) => { + let args = if cfg.args.is_empty() { + String::new() + } else { + format!(" {}", cfg.args.join(" ")) + }; + let cmd_display = format!("{command}{args}"); + lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); + } + (None, Some(url)) => { + lines.push(vec![" • URL: ".into(), url.clone().into()].into()); + } + _ => {} } if names.is_empty() { From b07c8dcab8285b14c43d279190aa348d318b5ac9 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 16:13:41 -0700 Subject: [PATCH 06/23] Added an e2e test --- codex-rs/core/tests/suite/mod.rs | 1 + codex-rs/core/tests/suite/rmcp_client.rs | 138 +++++++++++++++++ .../rmcp-client/src/bin/rmcp_test_server.rs | 143 ++++++++++++++++++ 3 files changed, 282 insertions(+) create mode 100644 codex-rs/core/tests/suite/rmcp_client.rs create mode 100644 codex-rs/rmcp-client/src/bin/rmcp_test_server.rs diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 2d91e330a8..dfff260160 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -12,6 +12,7 @@ mod live_cli; mod model_overrides; mod prompt_caching; mod review; +mod rmcp_client; mod rollout_list_find; mod seatbelt; mod stream_error_allows_next_turn; diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs new file mode 100644 index 0000000000..433ed739d3 --- /dev/null +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -0,0 +1,138 @@ +use std::collections::HashMap; + +use assert_cmd::cargo::cargo_bin; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::non_sandbox_test; +use core_test_support::responses; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use serde_json::Value; +use wiremock::Mock; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { + non_sandbox_test!(result); + + let server = responses::start_mock_server().await; + + let call_id = "call-123"; + let server_name = "rmcp"; + let tool_name = format!("{server_name}__echo"); + + let sse_body = responses::sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream")) + .expect(1) + .mount(&server) + .await; + + let test_server_bin = cargo_bin("rmcp_test_server"); + let expected_env_value = "propagated-env"; + + let fixture = test_codex() + .with_config(move |config| { + use codex_core::config_types::McpServerConfig; + config.use_experimental_use_rmcp_client = true; + config.mcp_servers.insert( + server_name.to_string(), + McpServerConfig { + command: test_server_bin.to_string_lossy().into_owned(), + args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_VALUE".to_string(), + expected_env_value.to_string(), + )])), + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + ); + }) + .build(&server) + .await?; + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "call the rmcp echo tool".into(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallBegin(_)) + }) + .await; + let EventMsg::McpToolCallBegin(begin) = begin_event else { + unreachable!("event guard guarantees McpToolCallBegin"); + }; + assert_eq!(begin.invocation.server, server_name); + assert_eq!(begin.invocation.tool, "echo"); + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + + let result = end + .result + .as_ref() + .expect("rmcp echo tool should return success"); + assert_eq!(result.is_error, Some(false)); + assert!( + result.content.is_empty(), + "content should default to an empty array" + ); + + let structured = result + .structured_content + .as_ref() + .expect("structured content"); + let Value::Object(map) = structured else { + panic!("structured content should be an object: {structured:?}"); + }; + let echo_value = map + .get("echo") + .and_then(Value::as_str) + .expect("echo payload present"); + assert_eq!(echo_value, "ping"); + let env_value = map + .get("env") + .and_then(Value::as_str) + .expect("env snapshot inserted"); + assert_eq!(env_value, expected_env_value); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + server.verify().await; + + Ok(()) +} diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs new file mode 100644 index 0000000000..2a3a9c3fc2 --- /dev/null +++ b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs @@ -0,0 +1,143 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use rmcp::ErrorData as McpError; +use rmcp::ServiceExt; +use rmcp::handler::server::ServerHandler; +use rmcp::model::CallToolRequestParam; +use rmcp::model::CallToolResult; +use rmcp::model::JsonObject; +use rmcp::model::ListToolsResult; +use rmcp::model::PaginatedRequestParam; +use rmcp::model::ServerCapabilities; +use rmcp::model::ServerInfo; +use rmcp::model::Tool; +use rmcp::transport::stdio; +use serde::Deserialize; +use serde_json::json; +use tokio::task; + +#[derive(Clone)] +struct TestToolServer { + tools: Arc>, +} + +impl TestToolServer { + fn new() -> Self { + let tools = vec![Self::echo_tool()]; + Self { + tools: Arc::new(tools), + } + } + + fn echo_tool() -> Tool { + let schema: JsonObject = serde_json::from_value(json!({ + "type": "object", + "properties": { + "message": { "type": "string" }, + "env_var": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + })) + .expect("echo tool schema should deserialize"); + + Tool::new( + Cow::Borrowed("echo"), + Cow::Borrowed("Echo back the provided message and include environment data."), + Arc::new(schema), + ) + } +} + +#[derive(Deserialize)] +struct EchoArgs { + message: String, + #[allow(dead_code)] + env_var: Option, +} + +impl ServerHandler for TestToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + name: Cow::Borrowed("codex-rmcp-test-server"), + version: Some(Cow::Borrowed("0.1.0")), + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_tool_list_changed() + .build(), + ..ServerInfo::default() + } + } + + fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + Send + '_ { + let tools = self.tools.clone(); + async move { + Ok(ListToolsResult { + tools: (*tools).clone(), + next_cursor: None, + }) + } + } + + fn call_tool( + &self, + request: CallToolRequestParam, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + Send + '_ { + async move { + match request.name.as_ref() { + "echo" => { + let args: EchoArgs = match request.arguments { + Some(arguments) => serde_json::from_value(serde_json::Value::Object( + arguments.into_iter().collect(), + )) + .map_err(|err| McpError::invalid_params(err.to_string(), None))?, + None => { + return Err(McpError::invalid_params( + "missing arguments for echo tool", + None, + )); + } + }; + + let env_snapshot: HashMap = std::env::vars().collect(); + let structured_content = json!({ + "echo": args.message, + "env": env_snapshot.get("MCP_TEST_VALUE"), + }); + + Ok(CallToolResult { + content: Vec::new(), + structured_content: Some(structured_content), + is_error: Some(false), + meta: None, + }) + } + other => Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )), + } + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Run the server with STDIO transport. If the client disconnects we simply + // bubble up the error so the process exits. + let service = TestToolServer::new(); + let running = service.serve(stdio()).await?; + + // Wait for the client to finish interacting with the server. + running.waiting().await?; + // Drain background tasks to ensure clean shutdown. + task::yield_now().await; + Ok(()) +} From b75703331bfe019fcae3cff6d26c34a49281fa5c Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 17:51:20 -0700 Subject: [PATCH 07/23] WIP integration test --- codex-rs/core/tests/common/lib.rs | 1 + codex-rs/core/tests/suite/rmcp_client.rs | 18 +++++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/codex-rs/core/tests/common/lib.rs b/codex-rs/core/tests/common/lib.rs index 0fdd60387b..83e8a017b7 100644 --- a/codex-rs/core/tests/common/lib.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -122,6 +122,7 @@ where .await .expect("timeout waiting for event") .expect("stream ended unexpectedly"); + eprintln!("found event: {ev:?}"); if predicate(&ev.msg) { return ev.msg; } diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 433ed739d3..5ecfd53fdd 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -9,11 +9,13 @@ use codex_core::protocol::SandboxPolicy; use codex_protocol::config_types::ReasoningSummary; use core_test_support::non_sandbox_test; use core_test_support::responses; +use core_test_support::responses::mount_sse_once; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use serde_json::Value; use wiremock::Mock; use wiremock::ResponseTemplate; +use wiremock::matchers::any; use wiremock::matchers::method; use wiremock::matchers::path; @@ -27,7 +29,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { let server_name = "rmcp"; let tool_name = format!("{server_name}__echo"); - let sse_body = responses::sse(vec![ + let sse_body_1 = responses::sse(vec![ serde_json::json!({ "type": "response.created", "response": {"id": "resp-1"} @@ -36,12 +38,13 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { responses::ev_completed("resp-1"), ]); - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream")) - .expect(1) - .mount(&server) - .await; + let sse_body_2 = responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."), + responses::ev_completed("resp-2"), + ]); + + mount_sse_once(&server, any(), sse_body_1).await; + mount_sse_once(&server, any(), sse_body_2).await; let test_server_bin = cargo_bin("rmcp_test_server"); let expected_env_value = "propagated-env"; @@ -130,6 +133,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { .expect("env snapshot inserted"); assert_eq!(env_value, expected_env_value); + eprintln!("waiting for task complete"); wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; server.verify().await; From 82a533051974a7a24ee3e394f4d33f536209c729 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Thu, 25 Sep 2025 18:07:29 -0700 Subject: [PATCH 08/23] test: limit mock sse response to once and add debug logs Limit mock SSE response to a single use in tests; add debug eprintlns to rmcp_client tests. Fix order of edition in Cargo.toml and enable tokio io-std feature. Move stdio helper into rmcp_test_server, and comment out ServerInfo name/version fields. --- codex-rs/core/tests/common/responses.rs | 1 + codex-rs/core/tests/suite/rmcp_client.rs | 2 ++ codex-rs/rmcp-client/Cargo.toml | 3 ++- codex-rs/rmcp-client/src/bin/rmcp_test_server.rs | 9 +++++---- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 2f55f17a52..7028b4a186 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -121,6 +121,7 @@ where .and(path("/v1/responses")) .and(matcher) .respond_with(sse_response(body)) + .up_to_n_times(1) .mount(server) .await; } diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 5ecfd53fdd..9401cc1504 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -23,6 +23,7 @@ use wiremock::matchers::path; async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { non_sandbox_test!(result); + eprintln!("waiting for task complete"); let server = responses::start_mock_server().await; let call_id = "call-123"; @@ -46,6 +47,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { mount_sse_once(&server, any(), sse_body_1).await; mount_sse_once(&server, any(), sse_body_2).await; + eprintln!("waiting for task complete"); let test_server_bin = cargo_bin("rmcp_test_server"); let expected_env_value = "propagated-env"; diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 52126330e8..da9989e531 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-rmcp-client" version = { workspace = true } -edition = "2024" [lints] workspace = true @@ -25,6 +25,7 @@ tokio = { version = "1", features = [ "process", "rt-multi-thread", "sync", + "io-std", "time", ] } tracing = { version = "0.1.41", features = ["log"] } diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs index 2a3a9c3fc2..21e0b38424 100644 --- a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs +++ b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs @@ -13,7 +13,6 @@ use rmcp::model::PaginatedRequestParam; use rmcp::model::ServerCapabilities; use rmcp::model::ServerInfo; use rmcp::model::Tool; -use rmcp::transport::stdio; use serde::Deserialize; use serde_json::json; use tokio::task; @@ -22,7 +21,9 @@ use tokio::task; struct TestToolServer { tools: Arc>, } - +pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) { + (tokio::io::stdin(), tokio::io::stdout()) +} impl TestToolServer { fn new() -> Self { let tools = vec![Self::echo_tool()]; @@ -61,8 +62,8 @@ struct EchoArgs { impl ServerHandler for TestToolServer { fn get_info(&self) -> ServerInfo { ServerInfo { - name: Cow::Borrowed("codex-rmcp-test-server"), - version: Some(Cow::Borrowed("0.1.0")), + // name: Cow::Borrowed("codex-rmcp-test-server"), + // version: Some(Cow::Borrowed("0.1.0")), capabilities: ServerCapabilities::builder() .enable_tools() .enable_tool_list_changed() From ab7a231e896d99636815a024650cc4669bb7eb7b Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 18:14:20 -0700 Subject: [PATCH 09/23] Cleanup --- codex-rs/core/tests/common/lib.rs | 1 - codex-rs/core/tests/suite/rmcp_client.rs | 48 +++++++------- .../rmcp-client/src/bin/rmcp_test_server.rs | 65 +++++++++---------- 3 files changed, 55 insertions(+), 59 deletions(-) diff --git a/codex-rs/core/tests/common/lib.rs b/codex-rs/core/tests/common/lib.rs index 83e8a017b7..0fdd60387b 100644 --- a/codex-rs/core/tests/common/lib.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -122,7 +122,6 @@ where .await .expect("timeout waiting for event") .expect("stream ended unexpectedly"); - eprintln!("found event: {ev:?}"); if predicate(&ev.msg) { return ev.msg; } diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 9401cc1504..20c966baac 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -13,42 +13,41 @@ use core_test_support::responses::mount_sse_once; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use serde_json::Value; -use wiremock::Mock; -use wiremock::ResponseTemplate; use wiremock::matchers::any; -use wiremock::matchers::method; -use wiremock::matchers::path; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { non_sandbox_test!(result); - eprintln!("waiting for task complete"); let server = responses::start_mock_server().await; let call_id = "call-123"; let server_name = "rmcp"; let tool_name = format!("{server_name}__echo"); - let sse_body_1 = responses::sse(vec![ - serde_json::json!({ - "type": "response.created", - "response": {"id": "resp-1"} - }), - responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), - responses::ev_completed("resp-1"), - ]); - - let sse_body_2 = responses::sse(vec![ - responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."), - responses::ev_completed("resp-2"), - ]); - - mount_sse_once(&server, any(), sse_body_1).await; - mount_sse_once(&server, any(), sse_body_2).await; + mount_sse_once( + &server, + any(), + responses::sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + any(), + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."), + responses::ev_completed("resp-2"), + ]), + ) + .await; - eprintln!("waiting for task complete"); - let test_server_bin = cargo_bin("rmcp_test_server"); let expected_env_value = "propagated-env"; let fixture = test_codex() @@ -58,7 +57,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { config.mcp_servers.insert( server_name.to_string(), McpServerConfig { - command: test_server_bin.to_string_lossy().into_owned(), + command: cargo_bin("rmcp_test_server").to_string_lossy().into_owned(), args: Vec::new(), env: Some(HashMap::from([( "MCP_TEST_VALUE".to_string(), @@ -135,7 +134,6 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { .expect("env snapshot inserted"); assert_eq!(env_value, expected_env_value); - eprintln!("waiting for task complete"); wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; server.verify().await; diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs index 21e0b38424..5a14bb2648 100644 --- a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs +++ b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs @@ -33,6 +33,7 @@ impl TestToolServer { } fn echo_tool() -> Tool { + #[expect(clippy::expect_used)] let schema: JsonObject = serde_json::from_value(json!({ "type": "object", "properties": { @@ -86,45 +87,43 @@ impl ServerHandler for TestToolServer { } } - fn call_tool( + async fn call_tool( &self, request: CallToolRequestParam, _context: rmcp::service::RequestContext, - ) -> impl std::future::Future> + Send + '_ { - async move { - match request.name.as_ref() { - "echo" => { - let args: EchoArgs = match request.arguments { - Some(arguments) => serde_json::from_value(serde_json::Value::Object( - arguments.into_iter().collect(), - )) - .map_err(|err| McpError::invalid_params(err.to_string(), None))?, - None => { - return Err(McpError::invalid_params( - "missing arguments for echo tool", - None, - )); - } - }; + ) -> Result { + match request.name.as_ref() { + "echo" => { + let args: EchoArgs = match request.arguments { + Some(arguments) => serde_json::from_value(serde_json::Value::Object( + arguments.into_iter().collect(), + )) + .map_err(|err| McpError::invalid_params(err.to_string(), None))?, + None => { + return Err(McpError::invalid_params( + "missing arguments for echo tool", + None, + )); + } + }; - let env_snapshot: HashMap = std::env::vars().collect(); - let structured_content = json!({ - "echo": args.message, - "env": env_snapshot.get("MCP_TEST_VALUE"), - }); + let env_snapshot: HashMap = std::env::vars().collect(); + let structured_content = json!({ + "echo": args.message, + "env": env_snapshot.get("MCP_TEST_VALUE"), + }); - Ok(CallToolResult { - content: Vec::new(), - structured_content: Some(structured_content), - is_error: Some(false), - meta: None, - }) - } - other => Err(McpError::invalid_params( - format!("unknown tool: {other}"), - None, - )), + Ok(CallToolResult { + content: Vec::new(), + structured_content: Some(structured_content), + is_error: Some(false), + meta: None, + }) } + other => Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )), } } } From 728747510237f2f02e7c1b1728ce7f4bf40cd7dd Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 18:26:30 -0700 Subject: [PATCH 10/23] Fix sandbox check --- codex-rs/core/tests/suite/rmcp_client.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 20c966baac..7ce6292b36 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -7,9 +7,9 @@ use codex_core::protocol::InputItem; use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_protocol::config_types::ReasoningSummary; -use core_test_support::non_sandbox_test; use core_test_support::responses; use core_test_support::responses::mount_sse_once; +use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use serde_json::Value; @@ -17,7 +17,7 @@ use wiremock::matchers::any; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { - non_sandbox_test!(result); + skip_if_no_network!(Ok(())); let server = responses::start_mock_server().await; From fcc117b9a92a573aee79f895cafe073c8f227212 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 18:54:03 -0700 Subject: [PATCH 11/23] Add some logs --- codex-rs/core/tests/suite/rmcp_client.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 7ce6292b36..e5c1a9347f 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -88,10 +88,13 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { }) .await?; + eprintln!("waiting for begin event"); let begin_event = wait_for_event(&fixture.codex, |ev| { matches!(ev, EventMsg::McpToolCallBegin(_)) }) .await; + + eprintln!("begin_event: {begin_event:?}"); let EventMsg::McpToolCallBegin(begin) = begin_event else { unreachable!("event guard guarantees McpToolCallBegin"); }; @@ -102,6 +105,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { matches!(ev, EventMsg::McpToolCallEnd(_)) }) .await; + eprintln!("end_event: {end_event:?}"); let EventMsg::McpToolCallEnd(end) = end_event else { unreachable!("event guard guarantees McpToolCallEnd"); }; @@ -134,7 +138,9 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { .expect("env snapshot inserted"); assert_eq!(env_value, expected_env_value); - wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + let task_complete_event = + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + eprintln!("task_complete_event: {task_complete_event:?}"); server.verify().await; From 35292f82bf81368b0558868d09574fa9ae300f16 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 25 Sep 2025 22:15:46 -0700 Subject: [PATCH 12/23] Logs for failing test --- codex-rs/Cargo.lock | 12 ++++++++ codex-rs/Cargo.toml | 1 + codex-rs/core/Cargo.toml | 1 + codex-rs/core/src/mcp_connection_manager.rs | 5 ++++ codex-rs/core/tests/suite/rmcp_client.rs | 30 ++++++++++++++----- .../rmcp-client/src/bin/rmcp_test_server.rs | 1 + 6 files changed, 42 insertions(+), 8 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 901bb1a7ca..edaa5e3812 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -690,6 +690,7 @@ dependencies = [ "core_test_support", "dirs", "env-flags", + "escargot", "eventsource-stream", "futures", "landlock", @@ -1698,6 +1699,17 @@ version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" +[[package]] +name = "escargot" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11c3aea32bc97b500c9ca6a72b768a26e558264303d101d3409cf6d57a9ed0cf" +dependencies = [ + "log", + "serde", + "serde_json", +] + [[package]] name = "event-listener" version = "5.4.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index d4c10ff01c..a909690366 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -82,6 +82,7 @@ dotenvy = "0.15.7" env-flags = "0.1.1" env_logger = "0.11.5" eventsource-stream = "0.2.3" +escargot = "0.5" futures = "0.3" icu_decimal = "2.0.0" icu_locale_core = "2.0.0" diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index f1d16dc405..ff09372113 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -81,6 +81,7 @@ openssl-sys = { workspace = true, features = ["vendored"] } [dev-dependencies] assert_cmd = { workspace = true } core_test_support = { workspace = true } +escargot = { workspace = true } maplit = { workspace = true } predicates = { workspace = true } pretty_assertions = { workspace = true } diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 5ec4f05757..5648e20b3b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -107,6 +107,9 @@ impl McpClientAdapter { params: mcp_types::InitializeRequestParams, startup_timeout: Duration, ) -> Result { + tracing::error!( + "new_stdio_client use_rmcp_client: {use_rmcp_client} program: {program:?} args: {args:?} env: {env:?} params: {params:?} startup_timeout: {startup_timeout:?}" + ); if use_rmcp_client { let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?); client.initialize(params, Some(startup_timeout)).await?; @@ -173,6 +176,8 @@ impl McpConnectionManager { return Ok((Self::default(), ClientStartErrors::default())); } + tracing::error!("new mcp_servers: {mcp_servers:?} use_rmcp_client: {use_rmcp_client}"); + // Launch all configured servers concurrently. let mut join_set = JoinSet::new(); let mut errors = ClientStartErrors::new(); diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index e5c1a9347f..cc6e972f7b 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; +use std::time::Duration; -use assert_cmd::cargo::cargo_bin; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; @@ -12,6 +12,8 @@ use core_test_support::responses::mount_sse_once; use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use core_test_support::wait_for_event_with_timeout; +use escargot::CargoBuild; use serde_json::Value; use wiremock::matchers::any; @@ -49,6 +51,13 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { .await; let expected_env_value = "propagated-env"; + let rmcp_test_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("rmcp_test_server") + .run()? + .path() + .to_string_lossy() + .into_owned(); let fixture = test_codex() .with_config(move |config| { @@ -57,13 +66,13 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { config.mcp_servers.insert( server_name.to_string(), McpServerConfig { - command: cargo_bin("rmcp_test_server").to_string_lossy().into_owned(), + command: rmcp_test_server_bin.clone(), args: Vec::new(), env: Some(HashMap::from([( "MCP_TEST_VALUE".to_string(), expected_env_value.to_string(), )])), - startup_timeout_sec: None, + startup_timeout_sec: Some(Duration::from_secs(10)), tool_timeout_sec: None, }, ); @@ -88,13 +97,18 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { }) .await?; - eprintln!("waiting for begin event"); - let begin_event = wait_for_event(&fixture.codex, |ev| { - matches!(ev, EventMsg::McpToolCallBegin(_)) - }) + eprintln!("waiting for mcp tool call begin event"); + let begin_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| { + eprintln!("ev: {ev:?}"); + matches!(ev, EventMsg::McpToolCallBegin(_)) + }, + Duration::from_secs(10), + ) .await; - eprintln!("begin_event: {begin_event:?}"); + eprintln!("mcp tool call begin event: {begin_event:?}"); let EventMsg::McpToolCallBegin(begin) = begin_event else { unreachable!("event guard guarantees McpToolCallBegin"); }; diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs index 5a14bb2648..d73c063623 100644 --- a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs +++ b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs @@ -130,6 +130,7 @@ impl ServerHandler for TestToolServer { #[tokio::main] async fn main() -> Result<(), Box> { + eprintln!("starting rmcp test server"); // Run the server with STDIO transport. If the client disconnects we simply // bubble up the error so the process exits. let service = TestToolServer::new(); From a81a34a4e2b9f26aa0ed40f3c17d4575e53cf7ce Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 09:08:28 -0700 Subject: [PATCH 13/23] PR review --- codex-rs/core/tests/suite/rmcp_client.rs | 2 +- codex-rs/rmcp-client/src/bin/rmcp_test_server.rs | 2 -- codex-rs/rmcp-client/src/rmcp_client.rs | 7 +++---- codex-rs/rmcp-client/src/utils.rs | 2 +- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index cc6e972f7b..2ebe9f011c 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::time::Duration; +use codex_core::config_types::McpServerConfig; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; @@ -61,7 +62,6 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { let fixture = test_codex() .with_config(move |config| { - use codex_core::config_types::McpServerConfig; config.use_experimental_use_rmcp_client = true; config.mcp_servers.insert( server_name.to_string(), diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs index d73c063623..23b2f93b38 100644 --- a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs +++ b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs @@ -63,8 +63,6 @@ struct EchoArgs { impl ServerHandler for TestToolServer { fn get_info(&self) -> ServerInfo { ServerInfo { - // name: Cow::Borrowed("codex-rmcp-test-server"), - // version: Some(Cow::Borrowed("0.1.0")), capabilities: ServerCapabilities::builder() .enable_tools() .enable_tool_list_changed() diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 472aa21e6b..c7ac1ecc9a 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -150,10 +150,9 @@ impl RmcpClient { timeout: Option, ) -> Result { let service = self.service().await?; - let rmcp_params = match params { - Some(p) => Some(convert_to_rmcp::<_, PaginatedRequestParam>(p)?), - None => None, - }; + let rmcp_params = params + .map(convert_to_rmcp::<_, PaginatedRequestParam>) + .transpose()?; let fut = service.list_tools(rmcp_params); let result = run_with_timeout(fut, timeout, "tools/list").await?; diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index 500fe9bc6c..6b7bd89424 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -22,7 +22,7 @@ where if let Some(duration) = timeout { let result = time::timeout(duration, fut) .await - .map_err(|_| anyhow!("timed out awaiting {label} after {duration:?}"))?; + .with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?; result.map_err(|err| anyhow!("{label} failed: {err}")) } else { fut.await.map_err(|err| anyhow!("{label} failed: {err}")) From da3f3729181d0e8228d9d8504a93aed8ebb68126 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 09:20:01 -0700 Subject: [PATCH 14/23] Rename test server --- codex-rs/core/tests/suite/rmcp_client.rs | 4 ++-- codex-rs/justfile | 3 +++ .../bin/{rmcp_test_server.rs => rmcp_test_stdio_server.rs} | 0 3 files changed, 5 insertions(+), 2 deletions(-) rename codex-rs/rmcp-client/src/bin/{rmcp_test_server.rs => rmcp_test_stdio_server.rs} (100%) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 2ebe9f011c..7ca078e868 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -19,7 +19,7 @@ use serde_json::Value; use wiremock::matchers::any; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { +async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); let server = responses::start_mock_server().await; @@ -54,7 +54,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { let expected_env_value = "propagated-env"; let rmcp_test_server_bin = CargoBuild::new() .package("codex-rmcp-client") - .bin("rmcp_test_server") + .bin("rmcp_test_stdio_server") .run()? .path() .to_string_lossy() diff --git a/codex-rs/justfile b/codex-rs/justfile index 850737efd6..15d6f15520 100644 --- a/codex-rs/justfile +++ b/codex-rs/justfile @@ -27,6 +27,9 @@ fmt: fix *args: cargo clippy --fix --all-features --tests --allow-dirty "$@" +clippy: + cargo clippy --all-features --tests --allow-dirty "$@" + install: rustup show active-toolchain cargo fetch diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_stdio_server.rs similarity index 100% rename from codex-rs/rmcp-client/src/bin/rmcp_test_server.rs rename to codex-rs/rmcp-client/src/bin/rmcp_test_stdio_server.rs From 76298fe49e3f07cb9d9747e83e745173ead90915 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 09:44:03 -0700 Subject: [PATCH 15/23] Config struct --- codex-rs/cli/src/mcp_cmd.rs | 80 ++++--- codex-rs/cli/tests/mcp_add_remove.rs | 12 +- codex-rs/cli/tests/mcp_list.rs | 18 +- codex-rs/core/src/config.rs | 43 ++-- codex-rs/core/src/config_types.rs | 252 +++++++++++++++----- codex-rs/core/src/mcp_connection_manager.rs | 64 ++--- codex-rs/core/tests/suite/rmcp_client.rs | 7 +- codex-rs/tui/src/history_cell.rs | 14 +- 8 files changed, 321 insertions(+), 169 deletions(-) diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index c931935e0e..588c6276be 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -13,6 +13,7 @@ use codex_core::config::find_codex_home; use codex_core::config::load_global_mcp_servers; use codex_core::config::write_global_mcp_servers; use codex_core::config_types::McpServerConfig; +use codex_core::config_types::McpServerTransportConfig; /// [experimental] Launch Codex as an MCP server or manage configured MCP servers. /// @@ -145,11 +146,11 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<( .with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?; let new_entry = McpServerConfig { - command: Some(command_bin), - args: command_args, + transport: McpServerTransportConfig::Stdio { + command: command_bin, + args: command_args, + }, env: env_map, - url: None, - bearer_token: None, startup_timeout_sec: None, tool_timeout_sec: None, }; @@ -208,13 +209,25 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul .map(|(k, v)| (k.clone(), v.clone())) .collect::>() }); + let transport = match &cfg.transport { + McpServerTransportConfig::Stdio { command, args } => serde_json::json!({ + "type": "stdio", + "command": command, + "args": args, + }), + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + serde_json::json!({ + "type": "streamable_http", + "url": url, + "bearer_token": bearer_token, + }) + } + }; + serde_json::json!({ "name": name, - "command": cfg.command, - "args": cfg.args, + "transport": transport, "env": env, - "url": cfg.url, - "bearer_token": cfg.bearer_token, "startup_timeout_sec": cfg .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -236,12 +249,12 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul let mut rows: Vec<[String; 4]> = Vec::new(); for (name, cfg) in entries { - let (command_display, args_display, env_display) = match (&cfg.command, &cfg.url) { - (Some(command), None) => { - let args = if cfg.args.is_empty() { + let (command_display, args_display, env_display) = match &cfg.transport { + McpServerTransportConfig::Stdio { command, args } => { + let args_display = if args.is_empty() { "-".to_string() } else { - cfg.args.join(" ") + args.join(" ") }; let env_str = match cfg.env.as_ref() { None => "-".to_string(), @@ -256,13 +269,12 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul .join(", ") } }; - (command.clone(), args, env_str) + (command.clone(), args_display, env_str) } - (None, Some(url)) => { - let bearer = cfg.bearer_token.clone().unwrap_or_else(|| "-".to_string()); + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + let bearer = bearer_token.clone().unwrap_or_else(|| "-".to_string()); (url.clone(), "-".to_string(), bearer) } - _ => ("-".to_string(), "-".to_string(), "-".to_string()), }; rows.push([name.clone(), command_display, args_display, env_display]); @@ -319,13 +331,22 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( .map(|(k, v)| (k.clone(), v.clone())) .collect::>() }); + let transport = match &server.transport { + McpServerTransportConfig::Stdio { command, args } => serde_json::json!({ + "type": "stdio", + "command": command, + "args": args, + }), + McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({ + "type": "streamable_http", + "url": url, + "bearer_token": bearer_token, + }), + }; let output = serde_json::to_string_pretty(&serde_json::json!({ "name": get_args.name, - "command": server.command, - "args": server.args, + "transport": transport, "env": env, - "url": server.url, - "bearer_token": server.bearer_token, "startup_timeout_sec": server .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -338,15 +359,16 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( } println!("{}", get_args.name); - match (&server.command, &server.url) { - (Some(command), None) => { + match &server.transport { + McpServerTransportConfig::Stdio { command, args } => { + println!(" transport: stdio"); println!(" command: {command}"); - let args = if server.args.is_empty() { + let args_display = if args.is_empty() { "-".to_string() } else { - server.args.join(" ") + args.join(" ") }; - println!(" args: {args}"); + println!(" args: {args_display}"); let env_display = match server.env.as_ref() { None => "-".to_string(), Some(map) if map.is_empty() => "-".to_string(), @@ -362,14 +384,12 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( }; println!(" env: {env_display}"); } - (None, Some(url)) => { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + println!(" transport: streamable_http"); println!(" url: {url}"); - let bearer = server.bearer_token.as_deref().unwrap_or("-"); + let bearer = bearer_token.as_deref().unwrap_or("-"); println!(" bearer_token: {bearer}"); } - _ => { - println!(" command/url: -"); - } } if let Some(timeout) = server.startup_timeout_sec { println!(" startup_timeout_sec: {}", timeout.as_secs_f64()); diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 6c328c0472..c433975ac2 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -2,6 +2,7 @@ use std::path::Path; use anyhow::Result; use codex_core::config::load_global_mcp_servers; +use codex_core::config_types::McpServerTransportConfig; use predicates::str::contains; use pretty_assertions::assert_eq; use tempfile::TempDir; @@ -26,11 +27,14 @@ fn add_and_remove_server_updates_global_config() -> Result<()> { let servers = load_global_mcp_servers(codex_home.path())?; assert_eq!(servers.len(), 1); let docs = servers.get("docs").expect("server should exist"); - assert_eq!(docs.command.as_deref(), Some("echo")); - assert_eq!(docs.args, vec!["hello".to_string()]); + match &docs.transport { + McpServerTransportConfig::Stdio { command, args } => { + assert_eq!(command, "echo"); + assert_eq!(args, &vec!["hello".to_string()]); + } + other => panic!("unexpected transport: {other:?}"), + } assert!(docs.env.is_none()); - assert!(docs.url.is_none()); - assert!(docs.bearer_token.is_none()); let mut remove_cmd = codex_command(codex_home.path())?; remove_cmd diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs index e53f42cc8f..b46ee66ef7 100644 --- a/codex-rs/cli/tests/mcp_list.rs +++ b/codex-rs/cli/tests/mcp_list.rs @@ -62,15 +62,22 @@ fn list_and_get_render_expected_output() -> Result<()> { assert_eq!(array.len(), 1); let entry = &array[0]; assert_eq!(entry.get("name"), Some(&JsonValue::String("docs".into()))); + let transport = entry + .get("transport") + .and_then(|value| value.as_object()) + .expect("transport object"); assert_eq!( - entry.get("command"), + transport.get("type"), + Some(&JsonValue::String("stdio".into())) + ); + assert_eq!( + transport.get("command"), Some(&JsonValue::String("docs-server".into())) ); - - let args = entry + let args = transport .get("args") - .and_then(|v| v.as_array()) - .expect("args array"); + .and_then(|value| value.as_array()) + .expect("transport args array"); assert_eq!( args, &vec![ @@ -90,6 +97,7 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(get_output.status.success()); let stdout = String::from_utf8(get_output.stdout)?; assert!(stdout.contains("docs")); + assert!(stdout.contains("transport: stdio")); assert!(stdout.contains("command: docs-server")); assert!(stdout.contains("args: --port 4000")); assert!(stdout.contains("env: TOKEN=secret")); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index e5df088a06..c40ec25b83 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -1,6 +1,7 @@ use crate::config_profile::ConfigProfile; use crate::config_types::History; use crate::config_types::McpServerConfig; +use crate::config_types::McpServerTransportConfig; use crate::config_types::Notifications; use crate::config_types::ReasoningSummaryFormat; use crate::config_types::SandboxWorkspaceWrite; @@ -314,16 +315,16 @@ pub fn write_global_mcp_servers( for (name, config) in servers { let mut entry = TomlTable::new(); entry.set_implicit(false); - match (&config.command, &config.url) { - (Some(command), None) => { + match &config.transport { + McpServerTransportConfig::Stdio { command, args } => { entry["command"] = toml_edit::value(command.clone()); - if !config.args.is_empty() { - let mut args = TomlArray::new(); - for arg in &config.args { - args.push(arg.clone()); + if !args.is_empty() { + let mut args_array = TomlArray::new(); + for arg in args { + args_array.push(arg.clone()); } - entry["args"] = TomlItem::Value(args.into()); + entry["args"] = TomlItem::Value(args_array.into()); } if let Some(env) = &config.env @@ -339,19 +340,12 @@ pub fn write_global_mcp_servers( entry["env"] = TomlItem::Table(env_table); } } - (None, Some(url)) => { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { entry["url"] = toml_edit::value(url.clone()); - if let Some(token) = &config.bearer_token { + if let Some(token) = bearer_token { entry["bearer_token"] = toml_edit::value(token.clone()); } } - _ => { - tracing::warn!( - "skipping MCP server `{}` with invalid transport configuration", - name - ); - continue; - } } if let Some(timeout) = config.startup_timeout_sec { @@ -1311,11 +1305,11 @@ exclude_slash_tmp = true servers.insert( "docs".to_string(), McpServerConfig { - command: Some("echo".to_string()), - args: vec!["hello".to_string()], + transport: McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string()], + }, env: None, - url: None, - bearer_token: None, startup_timeout_sec: Some(Duration::from_secs(3)), tool_timeout_sec: Some(Duration::from_secs(5)), }, @@ -1326,8 +1320,13 @@ exclude_slash_tmp = true let loaded = load_global_mcp_servers(codex_home.path())?; assert_eq!(loaded.len(), 1); let docs = loaded.get("docs").expect("docs entry"); - assert_eq!(docs.command, Some("echo".to_string())); - assert_eq!(docs.args, vec!["hello".to_string()]); + match &docs.transport { + McpServerTransportConfig::Stdio { command, args } => { + assert_eq!(command, "echo"); + assert_eq!(args, &vec!["hello".to_string()]); + } + other => panic!("unexpected transport {other:?}"), + } assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3))); assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5))); diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 51c28505e3..3a420a1cbe 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -13,22 +13,28 @@ use serde::Deserializer; use serde::Serialize; use serde::de::Error as SerdeError; +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum McpServerTransportConfig { + Stdio { + #[serde(alias = "commands")] + command: String, + #[serde(default)] + args: Vec, + }, + StreamableHttp { + url: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + bearer_token: Option, + }, +} + #[derive(Serialize, Debug, Clone, PartialEq)] pub struct McpServerConfig { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub command: Option, - - #[serde(default)] - pub args: Vec, - - #[serde(default)] - pub env: Option>, + pub transport: McpServerTransportConfig, #[serde(default, skip_serializing_if = "Option::is_none")] - pub url: Option, - - #[serde(default, skip_serializing_if = "Option::is_none")] - pub bearer_token: Option, + pub env: Option>, /// Startup timeout in seconds for initializing MCP server & initially listing tools. #[serde( @@ -39,7 +45,11 @@ pub struct McpServerConfig { pub startup_timeout_sec: Option, /// Default timeout for MCP tool calls initiated via this server. - #[serde(default, with = "option_duration_secs")] + #[serde( + default, + with = "option_duration_secs", + skip_serializing_if = "Option::is_none" + )] pub tool_timeout_sec: Option, } @@ -50,6 +60,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { { #[derive(Deserialize)] struct RawMcpServerConfig { + #[serde(default)] + transport: Option, #[serde(default)] command: Option, #[serde(default)] @@ -69,8 +81,27 @@ impl<'de> Deserialize<'de> for McpServerConfig { } let raw = RawMcpServerConfig::deserialize(deserializer)?; + let RawMcpServerConfig { + transport: raw_transport, + command, + args, + env, + url, + bearer_token, + startup_timeout_sec, + startup_timeout_ms, + tool_timeout_sec, + } = raw; + + if raw_transport.is_some() + && (command.is_some() || !args.is_empty() || url.is_some() || bearer_token.is_some()) + { + return Err(SerdeError::custom( + "`transport` must not be combined with legacy MCP transport fields", + )); + } - let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) { + let startup_timeout_sec = match (startup_timeout_sec, startup_timeout_ms) { (Some(sec), _) => { let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?; Some(duration) @@ -79,45 +110,73 @@ impl<'de> Deserialize<'de> for McpServerConfig { (None, None) => None, }; - let command = raw.command.and_then(normalize_string_option); - let url = raw.url.and_then(normalize_string_option); - - let has_command = command.is_some(); - let has_url = url.is_some(); - - if has_command && has_url { - return Err(SerdeError::custom( - "MCP server config must not set both `command` and `url`", - )); - } + let transport = if let Some(transport) = raw_transport { + match transport { + McpServerTransportConfig::Stdio { command, args } => { + let command = normalize_string_option(command).ok_or_else(|| { + SerdeError::custom( + "MCP server config `transport.stdio.command` must be non-empty", + ) + })?; + McpServerTransportConfig::Stdio { command, args } + } + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + let url = normalize_string_option(url).ok_or_else(|| { + SerdeError::custom( + "MCP server config `transport.streamable_http.url` must be non-empty", + ) + })?; + let bearer_token = bearer_token.and_then(normalize_string_option); + McpServerTransportConfig::StreamableHttp { url, bearer_token } + } + } + } else { + let command = command.and_then(normalize_string_option); + let url = url.and_then(normalize_string_option); + + match (command, url) { + (Some(command), None) => McpServerTransportConfig::Stdio { command, args }, + (None, Some(url)) => { + if !args.is_empty() { + return Err(SerdeError::custom( + "`args` is not supported when configuring MCP servers via `url`", + )); + } + if env.as_ref().is_some_and(|env| !env.is_empty()) { + return Err(SerdeError::custom( + "`env` is not supported when configuring MCP servers via `url`", + )); + } + + let bearer_token = bearer_token.and_then(normalize_string_option); + McpServerTransportConfig::StreamableHttp { url, bearer_token } + } + (Some(_), Some(_)) => { + return Err(SerdeError::custom( + "MCP server config must not set both `command` and `url`", + )); + } + (None, None) => { + return Err(SerdeError::custom( + "MCP server config must set either `command` or `url` or use `transport`", + )); + } + } + }; - if !has_command && !has_url { + if env.as_ref().is_some_and(|env| !env.is_empty()) + && matches!(transport, McpServerTransportConfig::StreamableHttp { .. }) + { return Err(SerdeError::custom( - "MCP server config must set either `command` or `url`", + "`env` is not supported when configuring MCP servers via `url`", )); } - if has_url { - if !raw.args.is_empty() { - return Err(SerdeError::custom( - "`args` is not supported when configuring MCP servers via `url`", - )); - } - if raw.env.as_ref().is_some_and(|env| !env.is_empty()) { - return Err(SerdeError::custom( - "`env` is not supported when configuring MCP servers via `url`", - )); - } - } - Ok(Self { - command, - args: raw.args, - env: raw.env, - url, - bearer_token: raw.bearer_token.and_then(normalize_string_option), + transport, + env, startup_timeout_sec, - tool_timeout_sec: raw.tool_timeout_sec, + tool_timeout_sec, }) } } @@ -137,7 +196,7 @@ mod tests { use pretty_assertions::assert_eq; #[test] - fn deserialize_command_server_config() { + fn deserialize_legacy_command_server_config() { let cfg: McpServerConfig = toml::from_str( r#" command = "echo" @@ -146,9 +205,54 @@ mod tests { ) .expect("should deserialize command config"); - assert_eq!(cfg.command.as_deref(), Some("echo")); - assert_eq!(cfg.args, vec!["hello", "world"]); - assert!(cfg.url.is_none()); + match cfg.transport { + McpServerTransportConfig::Stdio { command, args } => { + assert_eq!(command, "echo"); + assert_eq!(args, vec!["hello", "world"]); + } + other => panic!("unexpected transport: {other:?}"), + } + } + + #[test] + fn deserialize_transport_stdio_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + [transport] + type = "stdio" + command = "echo" + args = ["hi"] + "#, + ) + .expect("should deserialize stdio transport"); + + match cfg.transport { + McpServerTransportConfig::Stdio { command, args } => { + assert_eq!(command, "echo"); + assert_eq!(args, vec!["hi"]); + } + other => panic!("unexpected transport: {other:?}"), + } + } + + #[test] + fn deserialize_stdio_accepts_commands_alias() { + let cfg: McpServerConfig = toml::from_str( + r#" + [transport] + type = "stdio" + commands = "echo" + "#, + ) + .expect("should deserialize stdio transport with commands alias"); + + match cfg.transport { + McpServerTransportConfig::Stdio { command, args } => { + assert_eq!(command, "echo"); + assert!(args.is_empty()); + } + other => panic!("unexpected transport: {other:?}"), + } } #[test] @@ -161,10 +265,13 @@ mod tests { ) .expect("should deserialize http config"); - assert_eq!(cfg.url.as_deref(), Some("https://example.com/mcp")); - assert_eq!(cfg.bearer_token.as_deref(), Some("secret")); - assert!(cfg.command.is_none()); - assert!(cfg.args.is_empty()); + match cfg.transport { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + assert_eq!(url, "https://example.com/mcp"); + assert_eq!(bearer_token.as_deref(), Some("secret")); + } + other => panic!("unexpected transport: {other:?}"), + } assert!(cfg.env.is_none()); } @@ -180,6 +287,43 @@ mod tests { assert!(err.to_string().contains("must not set both")); } + + #[test] + fn deserialize_rejects_mixed_transport_fields() { + let err = toml::from_str::( + r#" + [transport] + type = "stdio" + command = "echo" + args = [] + + command = "echo" + "#, + ) + .expect_err("should reject mixing transport and legacy fields"); + + assert!( + err.to_string() + .contains("must not be combined with legacy MCP transport fields") + ); + } + + #[test] + fn deserialize_rejects_env_for_http_transport() { + let err = toml::from_str::( + r#" + [transport] + type = "streamable_http" + url = "https://example.com" + + [env] + FOO = "BAR" + "#, + ) + .expect_err("should reject env for http transport"); + + assert!(err.to_string().contains("`env` is not supported")); + } } mod option_duration_secs { diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index ee89b1834b..370d1cc322 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -29,6 +29,7 @@ use tracing::info; use tracing::warn; use crate::config_types::McpServerConfig; +use crate::config_types::McpServerTransportConfig; /// Delimiter used to separate the server name from the tool name in a fully /// qualified tool name. @@ -204,32 +205,11 @@ impl McpConnectionManager { continue; } - let has_command = cfg.command.is_some(); - let has_url = cfg.url.is_some(); - - if has_command && has_url { - errors.insert( - server_name.clone(), - anyhow!( - "MCP server `{}` must not set both `command` and `url`", - server_name - ), - ); - continue; - } - - if !has_command && !has_url { - errors.insert( - server_name.clone(), - anyhow!( - "MCP server `{}` must set either `command` or `url`", - server_name - ), - ); - continue; - } - - if cfg.url.is_some() && !use_rmcp_client { + if matches!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { .. } + ) && !use_rmcp_client + { info!( "skipping MCP server `{}` configured with url because rmcp client is disabled", server_name @@ -242,14 +222,7 @@ impl McpConnectionManager { let use_rmcp_client_flag = use_rmcp_client; join_set.spawn(async move { - let McpServerConfig { - command, - args, - env, - url, - bearer_token, - .. - } = cfg; + let McpServerConfig { transport, env, .. } = cfg; let params = mcp_types::InitializeRequestParams { capabilities: ClientCapabilities { experimental: None, @@ -271,8 +244,8 @@ impl McpConnectionManager { protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; - let client = match (command, url) { - (Some(command), None) => { + let client = match transport { + McpServerTransportConfig::Stdio { command, args } => { let command_os: OsString = command.into(); let args_os: Vec = args.into_iter().map(Into::into).collect(); McpClientAdapter::new_stdio_client( @@ -286,15 +259,16 @@ impl McpConnectionManager { .await .map(|c| (c, startup_timeout)) } - (None, Some(url)) => McpClientAdapter::new_streamable_http_client( - url, - bearer_token, - params, - startup_timeout, - ) - .await - .map(|c| (c, startup_timeout)), - _ => Err(anyhow!("invalid MCP server transport configuration")), + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + McpClientAdapter::new_streamable_http_client( + url, + bearer_token, + params, + startup_timeout, + ) + .await + .map(|c| (c, startup_timeout)) + } }; ((server_name, tool_timeout), client) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 7ca078e868..9981de7209 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::time::Duration; use codex_core::config_types::McpServerConfig; +use codex_core::config_types::McpServerTransportConfig; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; @@ -66,8 +67,10 @@ async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { config.mcp_servers.insert( server_name.to_string(), McpServerConfig { - command: rmcp_test_server_bin.clone(), - args: Vec::new(), + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin.clone(), + args: Vec::new(), + }, env: Some(HashMap::from([( "MCP_TEST_VALUE".to_string(), expected_env_value.to_string(), diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index 1e0d66b336..de719ca5c7 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -19,6 +19,7 @@ use crate::wrapping::word_wrap_line; use crate::wrapping::word_wrap_lines; use base64::Engine; use codex_core::config::Config; +use codex_core::config_types::McpServerTransportConfig; use codex_core::config_types::ReasoningSummaryFormat; use codex_core::plan_tool::PlanItemArg; use codex_core::plan_tool::StepStatus; @@ -848,20 +849,19 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Server: ".into(), server.clone().into()].into()); - match (&cfg.command, &cfg.url) { - (Some(command), None) => { - let args = if cfg.args.is_empty() { + match &cfg.transport { + McpServerTransportConfig::Stdio { command, args } => { + let args_suffix = if args.is_empty() { String::new() } else { - format!(" {}", cfg.args.join(" ")) + format!(" {}", args.join(" ")) }; - let cmd_display = format!("{command}{args}"); + let cmd_display = format!("{command}{args_suffix}"); lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); } - (None, Some(url)) => { + McpServerTransportConfig::StreamableHttp { url, .. } => { lines.push(vec![" • URL: ".into(), url.clone().into()].into()); } - _ => {} } if names.is_empty() { From d808be9022853f9184416ef0fb8690f028a30567 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 10:01:37 -0700 Subject: [PATCH 16/23] Test server and test passes --- codex-rs/Cargo.lock | 63 ++++- codex-rs/core/tests/suite/rmcp_client.rs | 217 +++++++++++++++++- codex-rs/justfile | 2 +- codex-rs/rmcp-client/Cargo.toml | 2 + ...t_stdio_server.rs => test_stdio_server.rs} | 2 +- .../src/bin/test_streamable_http_server.rs | 167 ++++++++++++++ 6 files changed, 438 insertions(+), 15 deletions(-) rename codex-rs/rmcp-client/src/bin/{rmcp_test_stdio_server.rs => test_stdio_server.rs} (98%) create mode 100644 codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 4aee6b7e4a..eb97878072 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -333,6 +333,54 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -956,6 +1004,7 @@ name = "codex-rmcp-client" version = "0.0.0" dependencies = [ "anyhow", + "axum", "futures", "mcp-types", "pretty_assertions", @@ -2885,6 +2934,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "mcp-types" version = "0.0.0" @@ -3671,7 +3726,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -3909,12 +3964,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af" dependencies = [ "base64", + "bytes", "chrono", "futures", "http", + "http-body", + "http-body-util", "paste", "pin-project-lite", "process-wrap", + "rand", "reqwest", "rmcp-macros", "schemars 1.0.4", @@ -3925,7 +3984,9 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", + "tower-service", "tracing", + "uuid", ] [[package]] diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 9981de7209..d74438d12e 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::net::TcpListener; use std::time::Duration; use codex_core::config_types::McpServerConfig; @@ -17,6 +18,11 @@ use core_test_support::wait_for_event; use core_test_support::wait_for_event_with_timeout; use escargot::CargoBuild; use serde_json::Value; +use tokio::net::TcpStream; +use tokio::process::Child; +use tokio::process::Command; +use tokio::time::Instant; +use tokio::time::sleep; use wiremock::matchers::any; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -55,7 +61,7 @@ async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { let expected_env_value = "propagated-env"; let rmcp_test_server_bin = CargoBuild::new() .package("codex-rmcp-client") - .bin("rmcp_test_stdio_server") + .bin("test_stdio_server") .run()? .path() .to_string_lossy() @@ -100,18 +106,13 @@ async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { }) .await?; - eprintln!("waiting for mcp tool call begin event"); let begin_event = wait_for_event_with_timeout( &fixture.codex, - |ev| { - eprintln!("ev: {ev:?}"); - matches!(ev, EventMsg::McpToolCallBegin(_)) - }, + |ev| matches!(ev, EventMsg::McpToolCallBegin(_)), Duration::from_secs(10), ) .await; - eprintln!("mcp tool call begin event: {begin_event:?}"); let EventMsg::McpToolCallBegin(begin) = begin_event else { unreachable!("event guard guarantees McpToolCallBegin"); }; @@ -122,7 +123,6 @@ async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { matches!(ev, EventMsg::McpToolCallEnd(_)) }) .await; - eprintln!("end_event: {end_event:?}"); let EventMsg::McpToolCallEnd(end) = end_event else { unreachable!("event guard guarantees McpToolCallEnd"); }; @@ -148,18 +148,211 @@ async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { .get("echo") .and_then(Value::as_str) .expect("echo payload present"); - assert_eq!(echo_value, "ping"); + assert_eq!(echo_value, "ECHOING: ping"); let env_value = map .get("env") .and_then(Value::as_str) .expect("env snapshot inserted"); assert_eq!(env_value, expected_env_value); - let task_complete_event = - wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - eprintln!("task_complete_event: {task_complete_event:?}"); + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; server.verify().await; Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn rmcp_streamable_http_tool_call_round_trip() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let call_id = "call-456"; + let server_name = "rmcp_http"; + let tool_name = format!("{server_name}__echo"); + + mount_sse_once( + &server, + any(), + responses::sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + any(), + responses::sse(vec![ + responses::ev_assistant_message( + "msg-1", + "rmcp streamable http echo tool completed successfully.", + ), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let expected_env_value = "propagated-env-http"; + let rmcp_http_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("test_streamable_http_server") + .run()? + .path() + .to_string_lossy() + .into_owned(); + + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + drop(listener); + let bind_addr = format!("127.0.0.1:{port}"); + let server_url = format!("http://{bind_addr}/mcp"); + + let mut http_server_child = Command::new(&rmcp_http_server_bin) + .kill_on_drop(true) + .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr) + .env("MCP_TEST_VALUE", expected_env_value) + .spawn()?; + + wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5)) + .await?; + + let fixture = test_codex() + .with_config(move |config| { + config.use_experimental_use_rmcp_client = true; + config.mcp_servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: server_url, + bearer_token: None, + }, + env: None, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + }, + ); + }) + .build(&server) + .await?; + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "call the rmcp streamable http echo tool".into(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| matches!(ev, EventMsg::McpToolCallBegin(_)), + Duration::from_secs(10), + ) + .await; + + let EventMsg::McpToolCallBegin(begin) = begin_event else { + unreachable!("event guard guarantees McpToolCallBegin"); + }; + assert_eq!(begin.invocation.server, server_name); + assert_eq!(begin.invocation.tool, "echo"); + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + + let result = end + .result + .as_ref() + .expect("rmcp echo tool should return success"); + assert_eq!(result.is_error, Some(false)); + assert!( + result.content.is_empty(), + "content should default to an empty array" + ); + + let structured = result + .structured_content + .as_ref() + .expect("structured content"); + let Value::Object(map) = structured else { + panic!("structured content should be an object: {structured:?}"); + }; + let echo_value = map + .get("echo") + .and_then(Value::as_str) + .expect("echo payload present"); + assert_eq!(echo_value, "ECHOING: ping"); + let env_value = map + .get("env") + .and_then(Value::as_str) + .expect("env snapshot inserted"); + assert_eq!(env_value, expected_env_value); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + server.verify().await; + + match http_server_child.try_wait() { + Ok(Some(_)) => {} + Ok(None) => { + let _ = http_server_child.kill().await; + } + Err(error) => { + eprintln!("failed to check streamable http server status: {error}"); + let _ = http_server_child.kill().await; + } + } + if let Err(error) = http_server_child.wait().await { + eprintln!("failed to await streamable http server shutdown: {error}"); + } + + Ok(()) +} + +async fn wait_for_streamable_http_server( + server_child: &mut Child, + address: &str, + timeout: Duration, +) -> anyhow::Result<()> { + let deadline = Instant::now() + timeout; + + loop { + if let Some(status) = server_child.try_wait()? { + return Err(anyhow::anyhow!( + "streamable HTTP server exited early with status {status}" + )); + } + + match TcpStream::connect(address).await { + Ok(_) => return Ok(()), + Err(error) => { + if Instant::now() >= deadline { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: {error}" + )); + } + } + } + + sleep(Duration::from_millis(50)).await; + } +} diff --git a/codex-rs/justfile b/codex-rs/justfile index 15d6f15520..4829da3606 100644 --- a/codex-rs/justfile +++ b/codex-rs/justfile @@ -28,7 +28,7 @@ fix *args: cargo clippy --fix --all-features --tests --allow-dirty "$@" clippy: - cargo clippy --all-features --tests --allow-dirty "$@" + cargo clippy --all-features --tests "$@" install: rustup show active-toolchain diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index c2abf0db08..a377b94f07 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -17,7 +17,9 @@ rmcp = { version = "0.7.0", default-features = false, features = [ "server", "transport-child-process", "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", ] } +axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } futures = { version = "0.3", default-features = false, features = ["std"] } reqwest = { version = "0.12", default-features = false, features = [ "json", diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_stdio_server.rs b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs similarity index 98% rename from codex-rs/rmcp-client/src/bin/rmcp_test_stdio_server.rs rename to codex-rs/rmcp-client/src/bin/test_stdio_server.rs index 23b2f93b38..2d380fa54e 100644 --- a/codex-rs/rmcp-client/src/bin/rmcp_test_stdio_server.rs +++ b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs @@ -107,7 +107,7 @@ impl ServerHandler for TestToolServer { let env_snapshot: HashMap = std::env::vars().collect(); let structured_content = json!({ - "echo": args.message, + "echo": format!("ECHOING: {}", args.message), "env": env_snapshot.get("MCP_TEST_VALUE"), }); diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs new file mode 100644 index 0000000000..eedd1cb1e3 --- /dev/null +++ b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -0,0 +1,167 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::io::ErrorKind; +use std::net::SocketAddr; +use std::sync::Arc; + +use axum::Router; +use rmcp::ErrorData as McpError; +use rmcp::handler::server::ServerHandler; +use rmcp::model::CallToolRequestParam; +use rmcp::model::CallToolResult; +use rmcp::model::JsonObject; +use rmcp::model::ListToolsResult; +use rmcp::model::PaginatedRequestParam; +use rmcp::model::ServerCapabilities; +use rmcp::model::ServerInfo; +use rmcp::model::Tool; +use rmcp::transport::StreamableHttpServerConfig; +use rmcp::transport::StreamableHttpService; +use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; +use serde::Deserialize; +use serde_json::json; +use tokio::task; + +#[derive(Clone)] +struct TestToolServer { + tools: Arc>, +} + +impl TestToolServer { + fn new() -> Self { + let tools = vec![Self::echo_tool()]; + Self { + tools: Arc::new(tools), + } + } + + fn echo_tool() -> Tool { + #[expect(clippy::expect_used)] + let schema: JsonObject = serde_json::from_value(json!({ + "type": "object", + "properties": { + "message": { "type": "string" }, + "env_var": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + })) + .expect("echo tool schema should deserialize"); + + Tool::new( + Cow::Borrowed("echo"), + Cow::Borrowed("Echo back the provided message and include environment data."), + Arc::new(schema), + ) + } +} + +#[derive(Deserialize)] +struct EchoArgs { + message: String, + #[allow(dead_code)] + env_var: Option, +} + +impl ServerHandler for TestToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_tool_list_changed() + .build(), + ..ServerInfo::default() + } + } + + fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + Send + '_ { + let tools = self.tools.clone(); + async move { + Ok(ListToolsResult { + tools: (*tools).clone(), + next_cursor: None, + }) + } + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + _context: rmcp::service::RequestContext, + ) -> Result { + match request.name.as_ref() { + "echo" => { + let args: EchoArgs = match request.arguments { + Some(arguments) => serde_json::from_value(serde_json::Value::Object( + arguments.into_iter().collect(), + )) + .map_err(|err| McpError::invalid_params(err.to_string(), None))?, + None => { + return Err(McpError::invalid_params( + "missing arguments for echo tool", + None, + )); + } + }; + + let env_snapshot: HashMap = std::env::vars().collect(); + let structured_content = json!({ + "echo": format!("ECHOING: {}", args.message), + "env": env_snapshot.get("MCP_TEST_VALUE"), + }); + + Ok(CallToolResult { + content: Vec::new(), + structured_content: Some(structured_content), + is_error: Some(false), + meta: None, + }) + } + other => Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )), + } + } +} + +fn parse_bind_addr() -> Result> { + let default_addr = "127.0.0.1:3920"; + let bind_addr = std::env::var("MCP_STREAMABLE_HTTP_BIND_ADDR") + .or_else(|_| std::env::var("BIND_ADDR")) + .unwrap_or_else(|_| default_addr.to_string()); + Ok(bind_addr.parse()?) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let bind_addr = parse_bind_addr()?; + let listener = match tokio::net::TcpListener::bind(&bind_addr).await { + Ok(listener) => listener, + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + eprintln!( + "failed to bind to {bind_addr}: {err}. make sure the process has network access" + ); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp"); + + let router = Router::new().nest_service( + "/mcp", + StreamableHttpService::new( + || Ok(TestToolServer::new()), + Arc::new(LocalSessionManager::default()), + StreamableHttpServerConfig::default(), + ), + ); + + axum::serve(listener, router).await?; + task::yield_now().await; + Ok(()) +} From 90937f96b4533fd7f146c84e3e82bc0146907f7f Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 11:22:45 -0700 Subject: [PATCH 17/23] Cleanup --- codex-rs/cli/src/mcp_cmd.rs | 143 +++++++++++++------- codex-rs/cli/tests/mcp_add_remove.rs | 9 +- codex-rs/cli/tests/mcp_list.rs | 8 ++ codex-rs/core/src/config.rs | 9 +- codex-rs/core/src/config_types.rs | 139 +++++++++++++++++-- codex-rs/core/src/mcp_connection_manager.rs | 4 +- codex-rs/core/tests/suite/rmcp_client.rs | 19 +-- codex-rs/justfile | 1 + codex-rs/tui/src/history_cell.rs | 2 +- 9 files changed, 250 insertions(+), 84 deletions(-) diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 588c6276be..9c5923c8b5 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -149,8 +149,8 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<( transport: McpServerTransportConfig::Stdio { command: command_bin, args: command_args, + env: env_map, }, - env: env_map, startup_timeout_sec: None, tool_timeout_sec: None, }; @@ -204,16 +204,20 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul let json_entries: Vec<_> = entries .into_iter() .map(|(name, cfg)| { - let env = cfg.env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }); + let env = match &cfg.transport { + McpServerTransportConfig::Stdio { env, .. } => env.as_ref().map(|env| { + env.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>() + }), + McpServerTransportConfig::StreamableHttp { .. } => None, + }; let transport = match &cfg.transport { - McpServerTransportConfig::Stdio { command, args } => serde_json::json!({ + McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ "type": "stdio", "command": command, "args": args, + "env": env, }), McpServerTransportConfig::StreamableHttp { url, bearer_token } => { serde_json::json!({ @@ -247,16 +251,18 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul return Ok(()); } - let mut rows: Vec<[String; 4]> = Vec::new(); + let mut stdio_rows: Vec<[String; 4]> = Vec::new(); + let mut http_rows: Vec<[String; 3]> = Vec::new(); + for (name, cfg) in entries { - let (command_display, args_display, env_display) = match &cfg.transport { - McpServerTransportConfig::Stdio { command, args } => { + match &cfg.transport { + McpServerTransportConfig::Stdio { command, args, env } => { let args_display = if args.is_empty() { "-".to_string() } else { args.join(" ") }; - let env_str = match cfg.env.as_ref() { + let env_display = match env.as_ref() { None => "-".to_string(), Some(map) if map.is_empty() => "-".to_string(), Some(map) => { @@ -269,48 +275,87 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul .join(", ") } }; - (command.clone(), args_display, env_str) + stdio_rows.push([name.clone(), command.clone(), args_display, env_display]); } McpServerTransportConfig::StreamableHttp { url, bearer_token } => { - let bearer = bearer_token.clone().unwrap_or_else(|| "-".to_string()); - (url.clone(), "-".to_string(), bearer) + let has_bearer = if bearer_token.is_some() { + "True" + } else { + "False" + }; + http_rows.push([name.clone(), url.clone(), has_bearer.into()]); } - }; - - rows.push([name.clone(), command_display, args_display, env_display]); + } } - let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()]; - for row in &rows { - for (i, cell) in row.iter().enumerate() { - widths[i] = widths[i].max(cell.len()); + if !stdio_rows.is_empty() { + let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()]; + for row in &stdio_rows { + for (i, cell) in row.iter().enumerate() { + widths[i] = widths[i].max(cell.len()); + } } - } - println!( - "{: Result<( }; if get_args.json { - let env = server.env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }); + let env = match &server.transport { + McpServerTransportConfig::Stdio { env, .. } => env.as_ref().map(|env| { + env.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>() + }), + McpServerTransportConfig::StreamableHttp { .. } => None, + }; let transport = match &server.transport { - McpServerTransportConfig::Stdio { command, args } => serde_json::json!({ + McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ "type": "stdio", "command": command, "args": args, + "env": env, }), McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({ "type": "streamable_http", @@ -360,7 +409,7 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( println!("{}", get_args.name); match &server.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { println!(" transport: stdio"); println!(" command: {command}"); let args_display = if args.is_empty() { @@ -369,7 +418,7 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( args.join(" ") }; println!(" args: {args_display}"); - let env_display = match server.env.as_ref() { + let env_display = match env.as_ref() { None => "-".to_string(), Some(map) if map.is_empty() => "-".to_string(), Some(map) => { diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index c433975ac2..cf3ea9f739 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -28,13 +28,13 @@ fn add_and_remove_server_updates_global_config() -> Result<()> { assert_eq!(servers.len(), 1); let docs = servers.get("docs").expect("server should exist"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { assert_eq!(command, "echo"); assert_eq!(args, &vec!["hello".to_string()]); + assert!(env.is_none()); } other => panic!("unexpected transport: {other:?}"), } - assert!(docs.env.is_none()); let mut remove_cmd = codex_command(codex_home.path())?; remove_cmd @@ -82,7 +82,10 @@ fn add_with_env_preserves_key_order_and_values() -> Result<()> { let servers = load_global_mcp_servers(codex_home.path())?; let envy = servers.get("envy").expect("server should exist"); - let env = envy.env.as_ref().expect("env should be present"); + let env = match &envy.transport { + McpServerTransportConfig::Stdio { env: Some(env), .. } => env, + other => panic!("unexpected transport: {other:?}"), + }; assert_eq!(env.len(), 2); assert_eq!(env.get("FOO"), Some(&"bar".to_string())); diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs index b46ee66ef7..b15d11faa8 100644 --- a/codex-rs/cli/tests/mcp_list.rs +++ b/codex-rs/cli/tests/mcp_list.rs @@ -74,6 +74,14 @@ fn list_and_get_render_expected_output() -> Result<()> { transport.get("command"), Some(&JsonValue::String("docs-server".into())) ); + let transport_env = transport + .get("env") + .and_then(|value| value.as_object()) + .expect("transport env map"); + assert_eq!( + transport_env.get("TOKEN"), + Some(&JsonValue::String("secret".into())) + ); let args = transport .get("args") .and_then(|value| value.as_array()) diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index c40ec25b83..864e9f9b17 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -316,7 +316,7 @@ pub fn write_global_mcp_servers( let mut entry = TomlTable::new(); entry.set_implicit(false); match &config.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { entry["command"] = toml_edit::value(command.clone()); if !args.is_empty() { @@ -327,7 +327,7 @@ pub fn write_global_mcp_servers( entry["args"] = TomlItem::Value(args_array.into()); } - if let Some(env) = &config.env + if let Some(env) = env && !env.is_empty() { let mut env_table = TomlTable::new(); @@ -1308,8 +1308,8 @@ exclude_slash_tmp = true transport: McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec!["hello".to_string()], + env: None, }, - env: None, startup_timeout_sec: Some(Duration::from_secs(3)), tool_timeout_sec: Some(Duration::from_secs(5)), }, @@ -1321,9 +1321,10 @@ exclude_slash_tmp = true assert_eq!(loaded.len(), 1); let docs = loaded.get("docs").expect("docs entry"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { assert_eq!(command, "echo"); assert_eq!(args, &vec!["hello".to_string()]); + assert!(env.is_none()); } other => panic!("unexpected transport {other:?}"), } diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 3a420a1cbe..c04c32b8a3 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -16,12 +16,16 @@ use serde::de::Error as SerdeError; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum McpServerTransportConfig { + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio Stdio { #[serde(alias = "commands")] command: String, #[serde(default)] args: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + env: Option>, }, + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http StreamableHttp { url: String, #[serde(default, skip_serializing_if = "Option::is_none")] @@ -33,9 +37,6 @@ pub enum McpServerTransportConfig { pub struct McpServerConfig { pub transport: McpServerTransportConfig, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub env: Option>, - /// Startup timeout in seconds for initializing MCP server & initially listing tools. #[serde( default, @@ -85,7 +86,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { transport: raw_transport, command, args, - env, + env: top_level_env, url, bearer_token, startup_timeout_sec, @@ -110,15 +111,23 @@ impl<'de> Deserialize<'de> for McpServerConfig { (None, None) => None, }; + let mut top_level_env = top_level_env; + let transport = if let Some(transport) = raw_transport { match transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { let command = normalize_string_option(command).ok_or_else(|| { SerdeError::custom( "MCP server config `transport.stdio.command` must be non-empty", ) })?; - McpServerTransportConfig::Stdio { command, args } + let env = normalize_env_option(env); + if top_level_env.is_some() && env.is_some() { + return Err(SerdeError::custom( + "MCP server config must not set both `env` and `transport.stdio.env`", + )); + } + McpServerTransportConfig::Stdio { command, args, env } } McpServerTransportConfig::StreamableHttp { url, bearer_token } => { let url = normalize_string_option(url).ok_or_else(|| { @@ -135,14 +144,18 @@ impl<'de> Deserialize<'de> for McpServerConfig { let url = url.and_then(normalize_string_option); match (command, url) { - (Some(command), None) => McpServerTransportConfig::Stdio { command, args }, + (Some(command), None) => McpServerTransportConfig::Stdio { + command, + args, + env: normalize_env_option(top_level_env.take()), + }, (None, Some(url)) => { if !args.is_empty() { return Err(SerdeError::custom( "`args` is not supported when configuring MCP servers via `url`", )); } - if env.as_ref().is_some_and(|env| !env.is_empty()) { + if top_level_env.as_ref().is_some_and(|env| !env.is_empty()) { return Err(SerdeError::custom( "`env` is not supported when configuring MCP servers via `url`", )); @@ -164,7 +177,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { } }; - if env.as_ref().is_some_and(|env| !env.is_empty()) + if top_level_env.as_ref().is_some_and(|env| !env.is_empty()) && matches!(transport, McpServerTransportConfig::StreamableHttp { .. }) { return Err(SerdeError::custom( @@ -172,9 +185,30 @@ impl<'de> Deserialize<'de> for McpServerConfig { )); } + let transport = match (transport, top_level_env.take()) { + (mut transport @ McpServerTransportConfig::Stdio { .. }, Some(env)) => { + if !env.is_empty() { + match &mut transport { + McpServerTransportConfig::Stdio { + env: target_env, .. + } => { + if target_env.is_some() { + return Err(SerdeError::custom( + "MCP server config must not set both `env` and `transport.stdio.env`", + )); + } + *target_env = Some(env); + } + McpServerTransportConfig::StreamableHttp { .. } => unreachable!(), + } + } + transport + } + (transport, _) => transport, + }; + Ok(Self { transport, - env, startup_timeout_sec, tool_timeout_sec, }) @@ -190,6 +224,10 @@ fn normalize_string_option(value: String) -> Option { } } +fn normalize_env_option(env: Option>) -> Option> { + env.and_then(|env| if env.is_empty() { None } else { Some(env) }) +} + #[cfg(test)] mod tests { use super::*; @@ -206,9 +244,33 @@ mod tests { .expect("should deserialize command config"); match cfg.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { assert_eq!(command, "echo"); assert_eq!(args, vec!["hello", "world"]); + assert!(env.is_none()); + } + other => panic!("unexpected transport: {other:?}"), + } + } + + #[test] + fn deserialize_legacy_command_server_config_with_env() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + + [env] + FOO = "BAR" + "#, + ) + .expect("should deserialize command config with env"); + + match cfg.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + assert_eq!(command, "echo"); + assert!(args.is_empty()); + let env = env.expect("env should be set"); + assert_eq!(env.get("FOO"), Some(&"BAR".to_string())); } other => panic!("unexpected transport: {other:?}"), } @@ -227,9 +289,35 @@ mod tests { .expect("should deserialize stdio transport"); match cfg.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { assert_eq!(command, "echo"); assert_eq!(args, vec!["hi"]); + assert!(env.is_none()); + } + other => panic!("unexpected transport: {other:?}"), + } + } + + #[test] + fn deserialize_transport_stdio_config_with_env() { + let cfg: McpServerConfig = toml::from_str( + r#" + [transport] + type = "stdio" + command = "echo" + + [transport.env] + FOO = "BAR" + "#, + ) + .expect("should deserialize stdio transport with env"); + + match cfg.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + assert_eq!(command, "echo"); + assert!(args.is_empty()); + let env = env.expect("env should be present"); + assert_eq!(env.get("FOO"), Some(&"BAR".to_string())); } other => panic!("unexpected transport: {other:?}"), } @@ -247,9 +335,10 @@ mod tests { .expect("should deserialize stdio transport with commands alias"); match cfg.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { assert_eq!(command, "echo"); assert!(args.is_empty()); + assert!(env.is_none()); } other => panic!("unexpected transport: {other:?}"), } @@ -272,7 +361,6 @@ mod tests { } other => panic!("unexpected transport: {other:?}"), } - assert!(cfg.env.is_none()); } #[test] @@ -324,6 +412,29 @@ mod tests { assert!(err.to_string().contains("`env` is not supported")); } + + #[test] + fn deserialize_rejects_duplicate_env_definitions() { + let err = toml::from_str::( + r#" + [transport] + type = "stdio" + command = "echo" + + [transport.env] + FOO = "BAR" + + [env] + BAZ = "QUX" + "#, + ) + .expect_err("should reject duplicate env definitions"); + + assert!( + err.to_string() + .contains("MCP server config must not set both `env` and `transport.stdio.env`") + ); + } } mod option_duration_secs { diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 370d1cc322..b5427ba1aa 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -222,7 +222,7 @@ impl McpConnectionManager { let use_rmcp_client_flag = use_rmcp_client; join_set.spawn(async move { - let McpServerConfig { transport, env, .. } = cfg; + let McpServerConfig { transport, .. } = cfg; let params = mcp_types::InitializeRequestParams { capabilities: ClientCapabilities { experimental: None, @@ -245,7 +245,7 @@ impl McpConnectionManager { }; let client = match transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, env } => { let command_os: OsString = command.into(); let args_os: Vec = args.into_iter().map(Into::into).collect(); McpClientAdapter::new_stdio_client( diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 4dd1c54e65..89e4622bf4 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -4,9 +4,7 @@ use std::time::Duration; use codex_core::config_types::McpServerConfig; use codex_core::config_types::McpServerTransportConfig; -use std::time::Duration; -use codex_core::config_types::McpServerConfig; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; @@ -29,11 +27,7 @@ use tokio::time::sleep; use wiremock::matchers::any; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn rmcp_stdio_tool_call_round_trip() -> anyhow::Result<()> { -use wiremock::matchers::any; - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { +async fn stdio_server_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); let server = responses::start_mock_server().await; @@ -83,11 +77,11 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { transport: McpServerTransportConfig::Stdio { command: rmcp_test_server_bin.clone(), args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_VALUE".to_string(), + expected_env_value.to_string(), + )])), }, - env: Some(HashMap::from([( - "MCP_TEST_VALUE".to_string(), - expected_env_value.to_string(), - )])), startup_timeout_sec: Some(Duration::from_secs(10)), tool_timeout_sec: None, }, @@ -170,7 +164,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn rmcp_streamable_http_tool_call_round_trip() -> anyhow::Result<()> { +async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); let server = responses::start_mock_server().await; @@ -239,7 +233,6 @@ async fn rmcp_streamable_http_tool_call_round_trip() -> anyhow::Result<()> { url: server_url, bearer_token: None, }, - env: None, startup_timeout_sec: Some(Duration::from_secs(10)), tool_timeout_sec: None, }, diff --git a/codex-rs/justfile b/codex-rs/justfile index 4829da3606..9ddc4a37aa 100644 --- a/codex-rs/justfile +++ b/codex-rs/justfile @@ -5,6 +5,7 @@ help: just -l # `codex` +alias c := codex codex *args: cargo run --bin codex -- "$@" diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index de719ca5c7..d374aaa104 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -850,7 +850,7 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Server: ".into(), server.clone().into()].into()); match &cfg.transport { - McpServerTransportConfig::Stdio { command, args } => { + McpServerTransportConfig::Stdio { command, args, .. } => { let args_suffix = if args.is_empty() { String::new() } else { From a44df7e4b12bf2a5d51726f62af030be480c3c02 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 11:55:01 -0700 Subject: [PATCH 18/23] Flatten config --- codex-rs/core/src/config_types.rs | 563 ++++++-------------- codex-rs/core/src/mcp_connection_manager.rs | 5 +- 2 files changed, 157 insertions(+), 411 deletions(-) diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index c04c32b8a3..7c67bd86f0 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -9,32 +9,11 @@ use std::time::Duration; use wildmatch::WildMatchPattern; use serde::Deserialize; -use serde::Deserializer; use serde::Serialize; -use serde::de::Error as SerdeError; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum McpServerTransportConfig { - /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio - Stdio { - #[serde(alias = "commands")] - command: String, - #[serde(default)] - args: Vec, - #[serde(default, skip_serializing_if = "Option::is_none")] - env: Option>, - }, - /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http - StreamableHttp { - url: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - bearer_token: Option, - }, -} - -#[derive(Serialize, Debug, Clone, PartialEq)] pub struct McpServerConfig { + #[serde(flatten)] pub transport: McpServerTransportConfig, /// Startup timeout in seconds for initializing MCP server & initially listing tools. @@ -46,395 +25,27 @@ pub struct McpServerConfig { pub startup_timeout_sec: Option, /// Default timeout for MCP tool calls initiated via this server. - #[serde( - default, - with = "option_duration_secs", - skip_serializing_if = "Option::is_none" - )] + #[serde(default, with = "option_duration_secs")] pub tool_timeout_sec: Option, } -impl<'de> Deserialize<'de> for McpServerConfig { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - struct RawMcpServerConfig { - #[serde(default)] - transport: Option, - #[serde(default)] - command: Option, - #[serde(default)] - args: Vec, - #[serde(default)] - env: Option>, - #[serde(default)] - url: Option, - #[serde(default)] - bearer_token: Option, - #[serde(default)] - startup_timeout_sec: Option, - #[serde(default)] - startup_timeout_ms: Option, - #[serde(default, with = "option_duration_secs")] - tool_timeout_sec: Option, - } - - let raw = RawMcpServerConfig::deserialize(deserializer)?; - let RawMcpServerConfig { - transport: raw_transport, - command, - args, - env: top_level_env, - url, - bearer_token, - startup_timeout_sec, - startup_timeout_ms, - tool_timeout_sec, - } = raw; - - if raw_transport.is_some() - && (command.is_some() || !args.is_empty() || url.is_some() || bearer_token.is_some()) - { - return Err(SerdeError::custom( - "`transport` must not be combined with legacy MCP transport fields", - )); - } - - let startup_timeout_sec = match (startup_timeout_sec, startup_timeout_ms) { - (Some(sec), _) => { - let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?; - Some(duration) - } - (None, Some(ms)) => Some(Duration::from_millis(ms)), - (None, None) => None, - }; - - let mut top_level_env = top_level_env; - - let transport = if let Some(transport) = raw_transport { - match transport { - McpServerTransportConfig::Stdio { command, args, env } => { - let command = normalize_string_option(command).ok_or_else(|| { - SerdeError::custom( - "MCP server config `transport.stdio.command` must be non-empty", - ) - })?; - let env = normalize_env_option(env); - if top_level_env.is_some() && env.is_some() { - return Err(SerdeError::custom( - "MCP server config must not set both `env` and `transport.stdio.env`", - )); - } - McpServerTransportConfig::Stdio { command, args, env } - } - McpServerTransportConfig::StreamableHttp { url, bearer_token } => { - let url = normalize_string_option(url).ok_or_else(|| { - SerdeError::custom( - "MCP server config `transport.streamable_http.url` must be non-empty", - ) - })?; - let bearer_token = bearer_token.and_then(normalize_string_option); - McpServerTransportConfig::StreamableHttp { url, bearer_token } - } - } - } else { - let command = command.and_then(normalize_string_option); - let url = url.and_then(normalize_string_option); - - match (command, url) { - (Some(command), None) => McpServerTransportConfig::Stdio { - command, - args, - env: normalize_env_option(top_level_env.take()), - }, - (None, Some(url)) => { - if !args.is_empty() { - return Err(SerdeError::custom( - "`args` is not supported when configuring MCP servers via `url`", - )); - } - if top_level_env.as_ref().is_some_and(|env| !env.is_empty()) { - return Err(SerdeError::custom( - "`env` is not supported when configuring MCP servers via `url`", - )); - } - - let bearer_token = bearer_token.and_then(normalize_string_option); - McpServerTransportConfig::StreamableHttp { url, bearer_token } - } - (Some(_), Some(_)) => { - return Err(SerdeError::custom( - "MCP server config must not set both `command` and `url`", - )); - } - (None, None) => { - return Err(SerdeError::custom( - "MCP server config must set either `command` or `url` or use `transport`", - )); - } - } - }; - - if top_level_env.as_ref().is_some_and(|env| !env.is_empty()) - && matches!(transport, McpServerTransportConfig::StreamableHttp { .. }) - { - return Err(SerdeError::custom( - "`env` is not supported when configuring MCP servers via `url`", - )); - } - - let transport = match (transport, top_level_env.take()) { - (mut transport @ McpServerTransportConfig::Stdio { .. }, Some(env)) => { - if !env.is_empty() { - match &mut transport { - McpServerTransportConfig::Stdio { - env: target_env, .. - } => { - if target_env.is_some() { - return Err(SerdeError::custom( - "MCP server config must not set both `env` and `transport.stdio.env`", - )); - } - *target_env = Some(env); - } - McpServerTransportConfig::StreamableHttp { .. } => unreachable!(), - } - } - transport - } - (transport, _) => transport, - }; - - Ok(Self { - transport, - startup_timeout_sec, - tool_timeout_sec, - }) - } -} - -fn normalize_string_option(value: String) -> Option { - let trimmed = value.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_string()) - } -} - -fn normalize_env_option(env: Option>) -> Option> { - env.and_then(|env| if env.is_empty() { None } else { Some(env) }) -} - -#[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn deserialize_legacy_command_server_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - args = ["hello", "world"] - "#, - ) - .expect("should deserialize command config"); - - match cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { - assert_eq!(command, "echo"); - assert_eq!(args, vec!["hello", "world"]); - assert!(env.is_none()); - } - other => panic!("unexpected transport: {other:?}"), - } - } - - #[test] - fn deserialize_legacy_command_server_config_with_env() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - - [env] - FOO = "BAR" - "#, - ) - .expect("should deserialize command config with env"); - - match cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { - assert_eq!(command, "echo"); - assert!(args.is_empty()); - let env = env.expect("env should be set"); - assert_eq!(env.get("FOO"), Some(&"BAR".to_string())); - } - other => panic!("unexpected transport: {other:?}"), - } - } - - #[test] - fn deserialize_transport_stdio_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - [transport] - type = "stdio" - command = "echo" - args = ["hi"] - "#, - ) - .expect("should deserialize stdio transport"); - - match cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { - assert_eq!(command, "echo"); - assert_eq!(args, vec!["hi"]); - assert!(env.is_none()); - } - other => panic!("unexpected transport: {other:?}"), - } - } - - #[test] - fn deserialize_transport_stdio_config_with_env() { - let cfg: McpServerConfig = toml::from_str( - r#" - [transport] - type = "stdio" - command = "echo" - - [transport.env] - FOO = "BAR" - "#, - ) - .expect("should deserialize stdio transport with env"); - - match cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { - assert_eq!(command, "echo"); - assert!(args.is_empty()); - let env = env.expect("env should be present"); - assert_eq!(env.get("FOO"), Some(&"BAR".to_string())); - } - other => panic!("unexpected transport: {other:?}"), - } - } - - #[test] - fn deserialize_stdio_accepts_commands_alias() { - let cfg: McpServerConfig = toml::from_str( - r#" - [transport] - type = "stdio" - commands = "echo" - "#, - ) - .expect("should deserialize stdio transport with commands alias"); - - match cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { - assert_eq!(command, "echo"); - assert!(args.is_empty()); - assert!(env.is_none()); - } - other => panic!("unexpected transport: {other:?}"), - } - } - - #[test] - fn deserialize_streamable_http_server_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - url = "https://example.com/mcp" - bearer_token = "secret" - "#, - ) - .expect("should deserialize http config"); - - match cfg.transport { - McpServerTransportConfig::StreamableHttp { url, bearer_token } => { - assert_eq!(url, "https://example.com/mcp"); - assert_eq!(bearer_token.as_deref(), Some("secret")); - } - other => panic!("unexpected transport: {other:?}"), - } - } - - #[test] - fn deserialize_rejects_invalid_transport_combo() { - let err = toml::from_str::( - r#" - command = "echo" - url = "https://example.com" - "#, - ) - .expect_err("should reject command+url"); - - assert!(err.to_string().contains("must not set both")); - } - - #[test] - fn deserialize_rejects_mixed_transport_fields() { - let err = toml::from_str::( - r#" - [transport] - type = "stdio" - command = "echo" - args = [] - - command = "echo" - "#, - ) - .expect_err("should reject mixing transport and legacy fields"); - - assert!( - err.to_string() - .contains("must not be combined with legacy MCP transport fields") - ); - } - - #[test] - fn deserialize_rejects_env_for_http_transport() { - let err = toml::from_str::( - r#" - [transport] - type = "streamable_http" - url = "https://example.com" - - [env] - FOO = "BAR" - "#, - ) - .expect_err("should reject env for http transport"); - - assert!(err.to_string().contains("`env` is not supported")); - } - - #[test] - fn deserialize_rejects_duplicate_env_definitions() { - let err = toml::from_str::( - r#" - [transport] - type = "stdio" - command = "echo" - - [transport.env] - FOO = "BAR" - - [env] - BAZ = "QUX" - "#, - ) - .expect_err("should reject duplicate env definitions"); - - assert!( - err.to_string() - .contains("MCP server config must not set both `env` and `transport.stdio.env`") - ); - } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")] +pub enum McpServerTransportConfig { + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio + Stdio { + command: String, + #[serde(default)] + args: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + env: Option>, + }, + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http + StreamableHttp { + url: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + bearer_token: Option, + }, } mod option_duration_secs { @@ -663,3 +274,139 @@ pub enum ReasoningSummaryFormat { None, Experimental, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn deserialize_stdio_command_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None + } + ); + } + + #[test] + fn deserialize_stdio_command_server_config_with_args() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string(), "world".to_string()], + env: None + } + ); + } + + #[test] + fn deserialize_stdio_command_server_config_with_arg_with_args_and_env() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + env = { "FOO" = "BAR" } + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string(), "world".to_string()], + env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])) + } + ); + } + + #[test] + fn deserialize_streamable_http_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: None + } + ); + } + + #[test] + fn deserialize_streamable_http_server_config_with_bearer_token() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + bearer_token = "secret" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: Some("secret".to_string()) + } + ); + } + + #[test] + fn deserialize_rejects_command_and_url() { + toml::from_str::( + r#" + command = "echo" + url = "https://example.com" + "#, + ) + .expect_err("should reject command+url"); + } + + #[test] + fn deserialize_rejects_env_for_http_transport() { + toml::from_str::( + r#" + url = "https://example.com" + env = { "FOO" = "BAR" } + "#, + ) + .expect_err("should reject env for http transport"); + } + + #[test] + fn deserialize_rejects_bearer_token_for_stdio_transport() { + toml::from_str::( + r#" + command = "echo" + bearer_token = "secret" + "#, + ) + .expect_err("should reject bearer token for stdio transport"); + } +} diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index b5427ba1aa..dc54509825 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -257,7 +257,6 @@ impl McpConnectionManager { startup_timeout, ) .await - .map(|c| (c, startup_timeout)) } McpServerTransportConfig::StreamableHttp { url, bearer_token } => { McpClientAdapter::new_streamable_http_client( @@ -267,9 +266,9 @@ impl McpConnectionManager { startup_timeout, ) .await - .map(|c| (c, startup_timeout)) } - }; + } + .map(|c| (c, startup_timeout)); ((server_name, tool_timeout), client) }); From 2b0e706d517d5b450f42dbb6724d5f5286ee8e4a Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 14:48:00 -0700 Subject: [PATCH 19/23] Add Bearer header --- codex-rs/rmcp-client/src/rmcp_client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index dca77d8b17..e127029472 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -106,7 +106,7 @@ impl RmcpClient { pub fn new_streamable_http_client(url: String, bearer_token: Option) -> Result { let mut config = StreamableHttpClientTransportConfig::with_uri(url); if let Some(token) = bearer_token { - config = config.auth_header(token); + config = config.auth_header(format!("Bearer {token}")); } let transport = StreamableHttpClientTransport::from_config(config); From 50fcc9ab30b698ceb38f31b96194c11390aa7e36 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 14:59:44 -0700 Subject: [PATCH 20/23] Fixed tool --- codex-rs/core/src/config_types.rs | 92 ++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 1 deletion(-) diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 7c67bd86f0..d23f18e8fc 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -3,6 +3,7 @@ // Note this file should generally be restricted to simple struct/enum // definitions that do not contain business logic. +use serde::Deserializer; use std::collections::HashMap; use std::path::PathBuf; use std::time::Duration; @@ -10,8 +11,9 @@ use wildmatch::WildMatchPattern; use serde::Deserialize; use serde::Serialize; +use serde::de::Error as SerdeError; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Debug, Clone, PartialEq)] pub struct McpServerConfig { #[serde(flatten)] pub transport: McpServerTransportConfig, @@ -29,6 +31,94 @@ pub struct McpServerConfig { pub tool_timeout_sec: Option, } +impl<'de> Deserialize<'de> for McpServerConfig { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct RawMcpServerConfig { + command: Option, + #[serde(default)] + args: Option>, + #[serde(default)] + env: Option>>, + + url: Option, + bearer_token: Option, + + #[serde(default)] + startup_timeout_sec: Option, + #[serde(default)] + startup_timeout_ms: Option, + #[serde(default, with = "option_duration_secs")] + tool_timeout_sec: Option, + } + + let raw = RawMcpServerConfig::deserialize(deserializer)?; + + let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) { + (Some(sec), _) => { + let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?; + Some(duration) + } + (None, Some(ms)) => Some(Duration::from_millis(ms)), + (None, None) => None, + }; + + fn throw_if_set(transport: &str, field: &str, value: Option<&T>) -> Result<(), E> + where + E: SerdeError, + { + if value.is_none() { + return Ok(()); + } + Err(E::custom(format!( + "{field} is not supported for {transport}", + ))) + } + + let transport = match raw { + RawMcpServerConfig { + command: Some(command), + args, + env, + url, + bearer_token, + .. + } => { + throw_if_set("stdio", "url", url.as_ref())?; + throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?; + McpServerTransportConfig::Stdio { + command, + args: args.unwrap_or_default(), + env: env.unwrap_or_default(), + } + } + RawMcpServerConfig { + url: Some(url), + bearer_token, + command, + args, + env, + .. + } => { + throw_if_set("streamable_http", "command", command.as_ref())?; + throw_if_set("streamable_http", "args", args.as_ref())?; + throw_if_set("streamable_http", "env", env.as_ref())?; + McpServerTransportConfig::StreamableHttp { url, bearer_token } + } + _ => return Err(SerdeError::custom("invalid transport")), + }; + + Ok(Self { + transport, + startup_timeout_sec, + tool_timeout_sec: raw.tool_timeout_sec, + }) + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(untagged, deny_unknown_fields, rename_all = "snake_case")] pub enum McpServerTransportConfig { From 1794dd8d09c1f428bb052c53b43c7e26d01cf0ed Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 16:04:17 -0700 Subject: [PATCH 21/23] Removed log --- codex-rs/core/src/mcp_connection_manager.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index dc54509825..015ca8c6bb 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -188,8 +188,6 @@ impl McpConnectionManager { return Ok((Self::default(), ClientStartErrors::default())); } - tracing::error!("new mcp_servers: {mcp_servers:?} use_rmcp_client: {use_rmcp_client}"); - // Launch all configured servers concurrently. let mut join_set = JoinSet::new(); let mut errors = ClientStartErrors::new(); From 268ba776257c3bdb3956a1d75ef0bf308fa69fda Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 17:41:04 -0700 Subject: [PATCH 22/23] Cleanup --- codex-rs/cli/src/mcp_cmd.rs | 19 ---- codex-rs/core/src/config.rs | 128 +++++++++++++++++++++++ codex-rs/core/src/config_types.rs | 7 +- codex-rs/core/tests/suite/rmcp_client.rs | 23 +++- 4 files changed, 151 insertions(+), 26 deletions(-) diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 9c5923c8b5..0cb448d8be 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -1,4 +1,3 @@ -use std::collections::BTreeMap; use std::collections::HashMap; use std::path::PathBuf; @@ -204,14 +203,6 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul let json_entries: Vec<_> = entries .into_iter() .map(|(name, cfg)| { - let env = match &cfg.transport { - McpServerTransportConfig::Stdio { env, .. } => env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }), - McpServerTransportConfig::StreamableHttp { .. } => None, - }; let transport = match &cfg.transport { McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ "type": "stdio", @@ -231,7 +222,6 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul serde_json::json!({ "name": name, "transport": transport, - "env": env, "startup_timeout_sec": cfg .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -371,14 +361,6 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( }; if get_args.json { - let env = match &server.transport { - McpServerTransportConfig::Stdio { env, .. } => env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }), - McpServerTransportConfig::StreamableHttp { .. } => None, - }; let transport = match &server.transport { McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ "type": "stdio", @@ -395,7 +377,6 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( let output = serde_json::to_string_pretty(&serde_json::json!({ "name": get_args.name, "transport": transport, - "env": env, "startup_timeout_sec": server .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 864e9f9b17..292b9f7b51 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -1361,6 +1361,134 @@ startup_timeout_ms = 2500 Ok(()) } + #[test] + fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "docs-server".to_string(), + args: vec!["--verbose".to_string()], + env: Some(HashMap::from([ + ("ZIG_VAR".to_string(), "3".to_string()), + ("ALPHA_VAR".to_string(), "1".to_string()), + ])), + }, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +command = "docs-server" +args = ["--verbose"] + +[mcp_servers.docs.env] +ALPHA_VAR = "1" +ZIG_VAR = "3" +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path())?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + assert_eq!(command, "docs-server"); + assert_eq!(args, &vec!["--verbose".to_string()]); + let env = env + .as_ref() + .expect("env should be preserved for stdio transport"); + assert_eq!(env.get("ALPHA_VAR"), Some(&"1".to_string())); + assert_eq!(env.get("ZIG_VAR"), Some(&"3".to_string())); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[test] + fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let mut servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: Some("secret-token".to_string()), + }, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +url = "https://example.com/mcp" +bearer_token = "secret-token" +startup_timeout_sec = 2.0 +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path())?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + assert_eq!(url, "https://example.com/mcp"); + assert_eq!(bearer_token.as_deref(), Some("secret-token")); + } + other => panic!("unexpected transport {other:?}"), + } + assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(2))); + + servers.insert( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: None, + }, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + ); + write_global_mcp_servers(codex_home.path(), &servers)?; + + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +url = "https://example.com/mcp" +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path())?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + assert_eq!(url, "https://example.com/mcp"); + assert!(bearer_token.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + #[tokio::test] async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> { let codex_home = TempDir::new()?; diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index d23f18e8fc..283ae3a480 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -42,7 +42,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { #[serde(default)] args: Option>, #[serde(default)] - env: Option>>, + env: Option>, url: Option, bearer_token: Option, @@ -92,7 +92,7 @@ impl<'de> Deserialize<'de> for McpServerConfig { McpServerTransportConfig::Stdio { command, args: args.unwrap_or_default(), - env: env.unwrap_or_default(), + env, } } RawMcpServerConfig { @@ -133,6 +133,9 @@ pub enum McpServerTransportConfig { /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http StreamableHttp { url: String, + /// A plain text bearer token to use for authentication. + /// This bearer token will be included in the HTTP request header as an `Authorization: Bearer ` header. + /// This should be used with caution because it lives on disk in clear text. #[serde(default, skip_serializing_if = "Option::is_none")] bearer_token: Option, }, diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 89e4622bf4..aed84ec603 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -26,7 +26,7 @@ use tokio::time::Instant; use tokio::time::sleep; use wiremock::matchers::any; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn stdio_server_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -163,7 +163,7 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { Ok(()) } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -342,15 +342,28 @@ async fn wait_for_streamable_http_server( )); } - match TcpStream::connect(address).await { - Ok(_) => return Ok(()), - Err(error) => { + let remaining = deadline.saturating_duration_since(Instant::now()); + + if remaining.is_zero() { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: deadline reached" + )); + } + + match tokio::time::timeout(remaining, TcpStream::connect(address)).await { + Ok(Ok(_)) => return Ok(()), + Ok(Err(error)) => { if Instant::now() >= deadline { return Err(anyhow::anyhow!( "timed out waiting for streamable HTTP server at {address}: {error}" )); } } + Err(_) => { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: connect call timed out" + )); + } } sleep(Duration::from_millis(50)).await; From 2c9f92383f85e9f1faa39a42d5618d808a05cb83 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Fri, 26 Sep 2025 18:09:50 -0700 Subject: [PATCH 23/23] Fixed test --- codex-rs/cli/tests/mcp_list.rs | 58 ++++++++++++---------------------- 1 file changed, 20 insertions(+), 38 deletions(-) diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs index b15d11faa8..6c83de19fa 100644 --- a/codex-rs/cli/tests/mcp_list.rs +++ b/codex-rs/cli/tests/mcp_list.rs @@ -4,6 +4,7 @@ use anyhow::Result; use predicates::str::contains; use pretty_assertions::assert_eq; use serde_json::Value as JsonValue; +use serde_json::json; use tempfile::TempDir; fn codex_command(codex_home: &Path) -> Result { @@ -58,48 +59,29 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(json_output.status.success()); let stdout = String::from_utf8(json_output.stdout)?; let parsed: JsonValue = serde_json::from_str(&stdout)?; - let array = parsed.as_array().expect("expected array"); - assert_eq!(array.len(), 1); - let entry = &array[0]; - assert_eq!(entry.get("name"), Some(&JsonValue::String("docs".into()))); - let transport = entry - .get("transport") - .and_then(|value| value.as_object()) - .expect("transport object"); assert_eq!( - transport.get("type"), - Some(&JsonValue::String("stdio".into())) - ); - assert_eq!( - transport.get("command"), - Some(&JsonValue::String("docs-server".into())) - ); - let transport_env = transport - .get("env") - .and_then(|value| value.as_object()) - .expect("transport env map"); - assert_eq!( - transport_env.get("TOKEN"), - Some(&JsonValue::String("secret".into())) - ); - let args = transport - .get("args") - .and_then(|value| value.as_array()) - .expect("transport args array"); - assert_eq!( - args, - &vec![ - JsonValue::String("--port".into()), - JsonValue::String("4000".into()) + parsed, + json!([ + { + "name": "docs", + "transport": { + "type": "stdio", + "command": "docs-server", + "args": [ + "--port", + "4000" + ], + "env": { + "TOKEN": "secret" + } + }, + "startup_timeout_sec": null, + "tool_timeout_sec": null + } ] + ) ); - let env = entry - .get("env") - .and_then(|v| v.as_object()) - .expect("env map"); - assert_eq!(env.get("TOKEN"), Some(&JsonValue::String("secret".into()))); - let mut get_cmd = codex_command(codex_home.path())?; let get_output = get_cmd.args(["mcp", "get", "docs"]).output()?; assert!(get_output.status.success());