diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ac792eaf3b..eb97878072 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -333,6 +333,54 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -956,8 +1004,11 @@ name = "codex-rmcp-client" version = "0.0.0" dependencies = [ "anyhow", + "axum", + "futures", "mcp-types", "pretty_assertions", + "reqwest", "rmcp", "serde", "serde_json", @@ -2883,6 +2934,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "mcp-types" version = "0.0.0" @@ -3669,7 +3726,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -3907,20 +3964,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af" dependencies = [ "base64", + "bytes", "chrono", "futures", + "http", + "http-body", + "http-body-util", "paste", "pin-project-lite", "process-wrap", + "rand", + "reqwest", "rmcp-macros", "schemars 1.0.4", "serde", "serde_json", + "sse-stream", "thiserror 2.0.16", "tokio", "tokio-stream", "tokio-util", + "tower-service", "tracing", + "uuid", ] [[package]] @@ -4468,6 +4534,19 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 465de71aac..0cb448d8be 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -1,4 +1,3 @@ -use std::collections::BTreeMap; use std::collections::HashMap; use std::path::PathBuf; @@ -13,6 +12,7 @@ use codex_core::config::find_codex_home; use codex_core::config::load_global_mcp_servers; use codex_core::config::write_global_mcp_servers; use codex_core::config_types::McpServerConfig; +use codex_core::config_types::McpServerTransportConfig; /// [experimental] Launch Codex as an MCP server or manage configured MCP servers. /// @@ -145,9 +145,11 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<( .with_context(|| format!("failed to load MCP servers from {}", codex_home.display()))?; let new_entry = McpServerConfig { - command: command_bin, - args: command_args, - env: env_map, + transport: McpServerTransportConfig::Stdio { + command: command_bin, + args: command_args, + env: env_map, + }, startup_timeout_sec: None, tool_timeout_sec: None, }; @@ -201,16 +203,25 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul let json_entries: Vec<_> = entries .into_iter() .map(|(name, cfg)| { - let env = cfg.env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }); + let transport = match &cfg.transport { + McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ + "type": "stdio", + "command": command, + "args": args, + "env": env, + }), + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + serde_json::json!({ + "type": "streamable_http", + "url": url, + "bearer_token": bearer_token, + }) + } + }; + serde_json::json!({ "name": name, - "command": cfg.command, - "args": cfg.args, - "env": env, + "transport": transport, "startup_timeout_sec": cfg .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -230,62 +241,111 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul return Ok(()); } - let mut rows: Vec<[String; 4]> = Vec::new(); - for (name, cfg) in entries { - let args = if cfg.args.is_empty() { - "-".to_string() - } else { - cfg.args.join(" ") - }; + let mut stdio_rows: Vec<[String; 4]> = Vec::new(); + let mut http_rows: Vec<[String; 3]> = Vec::new(); - let env = match cfg.env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") + for (name, cfg) in entries { + match &cfg.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + let args_display = if args.is_empty() { + "-".to_string() + } else { + args.join(" ") + }; + let env_display = match env.as_ref() { + None => "-".to_string(), + Some(map) if map.is_empty() => "-".to_string(), + Some(map) => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + }; + stdio_rows.push([name.clone(), command.clone(), args_display, env_display]); } - }; - - rows.push([name.clone(), cfg.command.clone(), args, env]); + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + let has_bearer = if bearer_token.is_some() { + "True" + } else { + "False" + }; + http_rows.push([name.clone(), url.clone(), has_bearer.into()]); + } + } } - let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()]; - for row in &rows { - for (i, cell) in row.iter().enumerate() { - widths[i] = widths[i].max(cell.len()); + if !stdio_rows.is_empty() { + let mut widths = ["Name".len(), "Command".len(), "Args".len(), "Env".len()]; + for row in &stdio_rows { + for (i, cell) in row.iter().enumerate() { + widths[i] = widths[i].max(cell.len()); + } } - } - println!( - "{: Result<( }; if get_args.json { - let env = server.env.as_ref().map(|env| { - env.iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - }); + let transport = match &server.transport { + McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ + "type": "stdio", + "command": command, + "args": args, + "env": env, + }), + McpServerTransportConfig::StreamableHttp { url, bearer_token } => serde_json::json!({ + "type": "streamable_http", + "url": url, + "bearer_token": bearer_token, + }), + }; let output = serde_json::to_string_pretty(&serde_json::json!({ "name": get_args.name, - "command": server.command, - "args": server.args, - "env": env, + "transport": transport, "startup_timeout_sec": server .startup_timeout_sec .map(|timeout| timeout.as_secs_f64()), @@ -323,27 +389,38 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<( } println!("{}", get_args.name); - println!(" command: {}", server.command); - let args = if server.args.is_empty() { - "-".to_string() - } else { - server.args.join(" ") - }; - println!(" args: {args}"); - let env_display = match server.env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") + match &server.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + println!(" transport: stdio"); + println!(" command: {command}"); + let args_display = if args.is_empty() { + "-".to_string() + } else { + args.join(" ") + }; + println!(" args: {args_display}"); + let env_display = match env.as_ref() { + None => "-".to_string(), + Some(map) if map.is_empty() => "-".to_string(), + Some(map) => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + }; + println!(" env: {env_display}"); } - }; - println!(" env: {env_display}"); + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + println!(" transport: streamable_http"); + println!(" url: {url}"); + let bearer = bearer_token.as_deref().unwrap_or("-"); + println!(" bearer_token: {bearer}"); + } + } if let Some(timeout) = server.startup_timeout_sec { println!(" startup_timeout_sec: {}", timeout.as_secs_f64()); } diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 9e54f0d867..cf3ea9f739 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -2,6 +2,7 @@ use std::path::Path; use anyhow::Result; use codex_core::config::load_global_mcp_servers; +use codex_core::config_types::McpServerTransportConfig; use predicates::str::contains; use pretty_assertions::assert_eq; use tempfile::TempDir; @@ -26,9 +27,14 @@ fn add_and_remove_server_updates_global_config() -> Result<()> { let servers = load_global_mcp_servers(codex_home.path())?; assert_eq!(servers.len(), 1); let docs = servers.get("docs").expect("server should exist"); - assert_eq!(docs.command, "echo"); - assert_eq!(docs.args, vec!["hello".to_string()]); - assert!(docs.env.is_none()); + match &docs.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + assert_eq!(command, "echo"); + assert_eq!(args, &vec!["hello".to_string()]); + assert!(env.is_none()); + } + other => panic!("unexpected transport: {other:?}"), + } let mut remove_cmd = codex_command(codex_home.path())?; remove_cmd @@ -76,7 +82,10 @@ fn add_with_env_preserves_key_order_and_values() -> Result<()> { let servers = load_global_mcp_servers(codex_home.path())?; let envy = servers.get("envy").expect("server should exist"); - let env = envy.env.as_ref().expect("env should be present"); + let env = match &envy.transport { + McpServerTransportConfig::Stdio { env: Some(env), .. } => env, + other => panic!("unexpected transport: {other:?}"), + }; assert_eq!(env.len(), 2); assert_eq!(env.get("FOO"), Some(&"bar".to_string())); diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs index e53f42cc8f..6c83de19fa 100644 --- a/codex-rs/cli/tests/mcp_list.rs +++ b/codex-rs/cli/tests/mcp_list.rs @@ -4,6 +4,7 @@ use anyhow::Result; use predicates::str::contains; use pretty_assertions::assert_eq; use serde_json::Value as JsonValue; +use serde_json::json; use tempfile::TempDir; fn codex_command(codex_home: &Path) -> Result { @@ -58,38 +59,35 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(json_output.status.success()); let stdout = String::from_utf8(json_output.stdout)?; let parsed: JsonValue = serde_json::from_str(&stdout)?; - let array = parsed.as_array().expect("expected array"); - assert_eq!(array.len(), 1); - let entry = &array[0]; - assert_eq!(entry.get("name"), Some(&JsonValue::String("docs".into()))); assert_eq!( - entry.get("command"), - Some(&JsonValue::String("docs-server".into())) - ); - - let args = entry - .get("args") - .and_then(|v| v.as_array()) - .expect("args array"); - assert_eq!( - args, - &vec![ - JsonValue::String("--port".into()), - JsonValue::String("4000".into()) + parsed, + json!([ + { + "name": "docs", + "transport": { + "type": "stdio", + "command": "docs-server", + "args": [ + "--port", + "4000" + ], + "env": { + "TOKEN": "secret" + } + }, + "startup_timeout_sec": null, + "tool_timeout_sec": null + } ] + ) ); - let env = entry - .get("env") - .and_then(|v| v.as_object()) - .expect("env map"); - assert_eq!(env.get("TOKEN"), Some(&JsonValue::String("secret".into()))); - let mut get_cmd = codex_command(codex_home.path())?; let get_output = get_cmd.args(["mcp", "get", "docs"]).output()?; assert!(get_output.status.success()); let stdout = String::from_utf8(get_output.stdout)?; assert!(stdout.contains("docs")); + assert!(stdout.contains("transport: stdio")); assert!(stdout.contains("command: docs-server")); assert!(stdout.contains("args: --port 4000")); assert!(stdout.contains("env: TOKEN=secret")); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 5b5b60f8df..292b9f7b51 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -1,6 +1,7 @@ use crate::config_profile::ConfigProfile; use crate::config_types::History; use crate::config_types::McpServerConfig; +use crate::config_types::McpServerTransportConfig; use crate::config_types::Notifications; use crate::config_types::ReasoningSummaryFormat; use crate::config_types::SandboxWorkspaceWrite; @@ -314,27 +315,37 @@ pub fn write_global_mcp_servers( for (name, config) in servers { let mut entry = TomlTable::new(); entry.set_implicit(false); - entry["command"] = toml_edit::value(config.command.clone()); + match &config.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + entry["command"] = toml_edit::value(command.clone()); + + if !args.is_empty() { + let mut args_array = TomlArray::new(); + for arg in args { + args_array.push(arg.clone()); + } + entry["args"] = TomlItem::Value(args_array.into()); + } - if !config.args.is_empty() { - let mut args = TomlArray::new(); - for arg in &config.args { - args.push(arg.clone()); + if let Some(env) = env + && !env.is_empty() + { + let mut env_table = TomlTable::new(); + env_table.set_implicit(false); + let mut pairs: Vec<_> = env.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (key, value) in pairs { + env_table.insert(key, toml_edit::value(value.clone())); + } + entry["env"] = TomlItem::Table(env_table); + } } - entry["args"] = TomlItem::Value(args.into()); - } - - if let Some(env) = &config.env - && !env.is_empty() - { - let mut env_table = TomlTable::new(); - env_table.set_implicit(false); - let mut pairs: Vec<_> = env.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - for (key, value) in pairs { - env_table.insert(key, toml_edit::value(value.clone())); + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + entry["url"] = toml_edit::value(url.clone()); + if let Some(token) = bearer_token { + entry["bearer_token"] = toml_edit::value(token.clone()); + } } - entry["env"] = TomlItem::Table(env_table); } if let Some(timeout) = config.startup_timeout_sec { @@ -1294,9 +1305,11 @@ exclude_slash_tmp = true servers.insert( "docs".to_string(), McpServerConfig { - command: "echo".to_string(), - args: vec!["hello".to_string()], - env: None, + transport: McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string()], + env: None, + }, startup_timeout_sec: Some(Duration::from_secs(3)), tool_timeout_sec: Some(Duration::from_secs(5)), }, @@ -1307,8 +1320,14 @@ exclude_slash_tmp = true let loaded = load_global_mcp_servers(codex_home.path())?; assert_eq!(loaded.len(), 1); let docs = loaded.get("docs").expect("docs entry"); - assert_eq!(docs.command, "echo"); - assert_eq!(docs.args, vec!["hello".to_string()]); + match &docs.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + assert_eq!(command, "echo"); + assert_eq!(args, &vec!["hello".to_string()]); + assert!(env.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3))); assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5))); @@ -1342,6 +1361,134 @@ startup_timeout_ms = 2500 Ok(()) } + #[test] + fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "docs-server".to_string(), + args: vec!["--verbose".to_string()], + env: Some(HashMap::from([ + ("ZIG_VAR".to_string(), "3".to_string()), + ("ALPHA_VAR".to_string(), "1".to_string()), + ])), + }, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +command = "docs-server" +args = ["--verbose"] + +[mcp_servers.docs.env] +ALPHA_VAR = "1" +ZIG_VAR = "3" +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path())?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::Stdio { command, args, env } => { + assert_eq!(command, "docs-server"); + assert_eq!(args, &vec!["--verbose".to_string()]); + let env = env + .as_ref() + .expect("env should be preserved for stdio transport"); + assert_eq!(env.get("ALPHA_VAR"), Some(&"1".to_string())); + assert_eq!(env.get("ZIG_VAR"), Some(&"3".to_string())); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[test] + fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let mut servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: Some("secret-token".to_string()), + }, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +url = "https://example.com/mcp" +bearer_token = "secret-token" +startup_timeout_sec = 2.0 +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path())?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + assert_eq!(url, "https://example.com/mcp"); + assert_eq!(bearer_token.as_deref(), Some("secret-token")); + } + other => panic!("unexpected transport {other:?}"), + } + assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(2))); + + servers.insert( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: None, + }, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + ); + write_global_mcp_servers(codex_home.path(), &servers)?; + + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +url = "https://example.com/mcp" +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path())?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + assert_eq!(url, "https://example.com/mcp"); + assert!(bearer_token.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + #[tokio::test] async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> { let codex_home = TempDir::new()?; diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index d273b23d69..283ae3a480 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -3,25 +3,20 @@ // Note this file should generally be restricted to simple struct/enum // definitions that do not contain business logic. +use serde::Deserializer; use std::collections::HashMap; use std::path::PathBuf; use std::time::Duration; use wildmatch::WildMatchPattern; use serde::Deserialize; -use serde::Deserializer; use serde::Serialize; use serde::de::Error as SerdeError; #[derive(Serialize, Debug, Clone, PartialEq)] pub struct McpServerConfig { - pub command: String, - - #[serde(default)] - pub args: Vec, - - #[serde(default)] - pub env: Option>, + #[serde(flatten)] + pub transport: McpServerTransportConfig, /// Startup timeout in seconds for initializing MCP server & initially listing tools. #[serde( @@ -43,11 +38,15 @@ impl<'de> Deserialize<'de> for McpServerConfig { { #[derive(Deserialize)] struct RawMcpServerConfig { - command: String, + command: Option, #[serde(default)] - args: Vec, + args: Option>, #[serde(default)] env: Option>, + + url: Option, + bearer_token: Option, + #[serde(default)] startup_timeout_sec: Option, #[serde(default)] @@ -67,16 +66,81 @@ impl<'de> Deserialize<'de> for McpServerConfig { (None, None) => None, }; + fn throw_if_set(transport: &str, field: &str, value: Option<&T>) -> Result<(), E> + where + E: SerdeError, + { + if value.is_none() { + return Ok(()); + } + Err(E::custom(format!( + "{field} is not supported for {transport}", + ))) + } + + let transport = match raw { + RawMcpServerConfig { + command: Some(command), + args, + env, + url, + bearer_token, + .. + } => { + throw_if_set("stdio", "url", url.as_ref())?; + throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?; + McpServerTransportConfig::Stdio { + command, + args: args.unwrap_or_default(), + env, + } + } + RawMcpServerConfig { + url: Some(url), + bearer_token, + command, + args, + env, + .. + } => { + throw_if_set("streamable_http", "command", command.as_ref())?; + throw_if_set("streamable_http", "args", args.as_ref())?; + throw_if_set("streamable_http", "env", env.as_ref())?; + McpServerTransportConfig::StreamableHttp { url, bearer_token } + } + _ => return Err(SerdeError::custom("invalid transport")), + }; + Ok(Self { - command: raw.command, - args: raw.args, - env: raw.env, + transport, startup_timeout_sec, tool_timeout_sec: raw.tool_timeout_sec, }) } } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")] +pub enum McpServerTransportConfig { + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio + Stdio { + command: String, + #[serde(default)] + args: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + env: Option>, + }, + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http + StreamableHttp { + url: String, + /// A plain text bearer token to use for authentication. + /// This bearer token will be included in the HTTP request header as an `Authorization: Bearer ` header. + /// This should be used with caution because it lives on disk in clear text. + #[serde(default, skip_serializing_if = "Option::is_none")] + bearer_token: Option, + }, +} + mod option_duration_secs { use serde::Deserialize; use serde::Deserializer; @@ -303,3 +367,139 @@ pub enum ReasoningSummaryFormat { None, Experimental, } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn deserialize_stdio_command_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None + } + ); + } + + #[test] + fn deserialize_stdio_command_server_config_with_args() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string(), "world".to_string()], + env: None + } + ); + } + + #[test] + fn deserialize_stdio_command_server_config_with_arg_with_args_and_env() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + env = { "FOO" = "BAR" } + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string(), "world".to_string()], + env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])) + } + ); + } + + #[test] + fn deserialize_streamable_http_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: None + } + ); + } + + #[test] + fn deserialize_streamable_http_server_config_with_bearer_token() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + bearer_token = "secret" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token: Some("secret".to_string()) + } + ); + } + + #[test] + fn deserialize_rejects_command_and_url() { + toml::from_str::( + r#" + command = "echo" + url = "https://example.com" + "#, + ) + .expect_err("should reject command+url"); + } + + #[test] + fn deserialize_rejects_env_for_http_transport() { + toml::from_str::( + r#" + url = "https://example.com" + env = { "FOO" = "BAR" } + "#, + ) + .expect_err("should reject env for http transport"); + } + + #[test] + fn deserialize_rejects_bearer_token_for_stdio_transport() { + toml::from_str::( + r#" + command = "echo" + bearer_token = "secret" + "#, + ) + .expect_err("should reject bearer token for stdio transport"); + } +} diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 5648e20b3b..015ca8c6bb 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -29,6 +29,7 @@ use tracing::info; use tracing::warn; use crate::config_types::McpServerConfig; +use crate::config_types::McpServerTransportConfig; /// Delimiter used to separate the server name from the tool name in a fully /// qualified tool name. @@ -121,6 +122,17 @@ impl McpClientAdapter { } } + async fn new_streamable_http_client( + url: String, + bearer_token: Option, + params: mcp_types::InitializeRequestParams, + startup_timeout: Duration, + ) -> Result { + let client = Arc::new(RmcpClient::new_streamable_http_client(url, bearer_token)?); + client.initialize(params, Some(startup_timeout)).await?; + Ok(McpClientAdapter::Rmcp(client)) + } + async fn list_tools( &self, params: Option, @@ -176,8 +188,6 @@ impl McpConnectionManager { return Ok((Self::default(), ClientStartErrors::default())); } - tracing::error!("new mcp_servers: {mcp_servers:?} use_rmcp_client: {use_rmcp_client}"); - // Launch all configured servers concurrently. let mut join_set = JoinSet::new(); let mut errors = ClientStartErrors::new(); @@ -193,16 +203,24 @@ impl McpConnectionManager { continue; } + if matches!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { .. } + ) && !use_rmcp_client + { + info!( + "skipping MCP server `{}` configured with url because rmcp client is disabled", + server_name + ); + continue; + } + let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT); let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT); let use_rmcp_client_flag = use_rmcp_client; join_set.spawn(async move { - let McpServerConfig { - command, args, env, .. - } = cfg; - let command_os: OsString = command.into(); - let args_os: Vec = args.into_iter().map(Into::into).collect(); + let McpServerConfig { transport, .. } = cfg; let params = mcp_types::InitializeRequestParams { capabilities: ClientCapabilities { experimental: None, @@ -224,15 +242,30 @@ impl McpConnectionManager { protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; - let client = McpClientAdapter::new_stdio_client( - use_rmcp_client_flag, - command_os, - args_os, - env, - params, - startup_timeout, - ) - .await + let client = match transport { + McpServerTransportConfig::Stdio { command, args, env } => { + let command_os: OsString = command.into(); + let args_os: Vec = args.into_iter().map(Into::into).collect(); + McpClientAdapter::new_stdio_client( + use_rmcp_client_flag, + command_os, + args_os, + env, + params.clone(), + startup_timeout, + ) + .await + } + McpServerTransportConfig::StreamableHttp { url, bearer_token } => { + McpClientAdapter::new_streamable_http_client( + url, + bearer_token, + params, + startup_timeout, + ) + .await + } + } .map(|c| (c, startup_timeout)); ((server_name, tool_timeout), client) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 2ebe9f011c..aed84ec603 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -1,7 +1,10 @@ use std::collections::HashMap; +use std::net::TcpListener; use std::time::Duration; use codex_core::config_types::McpServerConfig; +use codex_core::config_types::McpServerTransportConfig; + use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; @@ -16,10 +19,15 @@ use core_test_support::wait_for_event; use core_test_support::wait_for_event_with_timeout; use escargot::CargoBuild; use serde_json::Value; +use tokio::net::TcpStream; +use tokio::process::Child; +use tokio::process::Command; +use tokio::time::Instant; +use tokio::time::sleep; use wiremock::matchers::any; -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn stdio_server_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); let server = responses::start_mock_server().await; @@ -54,7 +62,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { let expected_env_value = "propagated-env"; let rmcp_test_server_bin = CargoBuild::new() .package("codex-rmcp-client") - .bin("rmcp_test_server") + .bin("test_stdio_server") .run()? .path() .to_string_lossy() @@ -66,12 +74,14 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { config.mcp_servers.insert( server_name.to_string(), McpServerConfig { - command: rmcp_test_server_bin.clone(), - args: Vec::new(), - env: Some(HashMap::from([( - "MCP_TEST_VALUE".to_string(), - expected_env_value.to_string(), - )])), + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin.clone(), + args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_VALUE".to_string(), + expected_env_value.to_string(), + )])), + }, startup_timeout_sec: Some(Duration::from_secs(10)), tool_timeout_sec: None, }, @@ -97,18 +107,164 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { }) .await?; - eprintln!("waiting for mcp tool call begin event"); let begin_event = wait_for_event_with_timeout( &fixture.codex, - |ev| { - eprintln!("ev: {ev:?}"); - matches!(ev, EventMsg::McpToolCallBegin(_)) - }, + |ev| matches!(ev, EventMsg::McpToolCallBegin(_)), + Duration::from_secs(10), + ) + .await; + + let EventMsg::McpToolCallBegin(begin) = begin_event else { + unreachable!("event guard guarantees McpToolCallBegin"); + }; + assert_eq!(begin.invocation.server, server_name); + assert_eq!(begin.invocation.tool, "echo"); + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + + let result = end + .result + .as_ref() + .expect("rmcp echo tool should return success"); + assert_eq!(result.is_error, Some(false)); + assert!( + result.content.is_empty(), + "content should default to an empty array" + ); + + let structured = result + .structured_content + .as_ref() + .expect("structured content"); + let Value::Object(map) = structured else { + panic!("structured content should be an object: {structured:?}"); + }; + let echo_value = map + .get("echo") + .and_then(Value::as_str) + .expect("echo payload present"); + assert_eq!(echo_value, "ECHOING: ping"); + let env_value = map + .get("env") + .and_then(Value::as_str) + .expect("env snapshot inserted"); + assert_eq!(env_value, expected_env_value); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + server.verify().await; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let call_id = "call-456"; + let server_name = "rmcp_http"; + let tool_name = format!("{server_name}__echo"); + + mount_sse_once( + &server, + any(), + responses::sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + any(), + responses::sse(vec![ + responses::ev_assistant_message( + "msg-1", + "rmcp streamable http echo tool completed successfully.", + ), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let expected_env_value = "propagated-env-http"; + let rmcp_http_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("test_streamable_http_server") + .run()? + .path() + .to_string_lossy() + .into_owned(); + + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + drop(listener); + let bind_addr = format!("127.0.0.1:{port}"); + let server_url = format!("http://{bind_addr}/mcp"); + + let mut http_server_child = Command::new(&rmcp_http_server_bin) + .kill_on_drop(true) + .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr) + .env("MCP_TEST_VALUE", expected_env_value) + .spawn()?; + + wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5)) + .await?; + + let fixture = test_codex() + .with_config(move |config| { + config.use_experimental_use_rmcp_client = true; + config.mcp_servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: server_url, + bearer_token: None, + }, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + }, + ); + }) + .build(&server) + .await?; + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "call the rmcp streamable http echo tool".into(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| matches!(ev, EventMsg::McpToolCallBegin(_)), Duration::from_secs(10), ) .await; - eprintln!("mcp tool call begin event: {begin_event:?}"); let EventMsg::McpToolCallBegin(begin) = begin_event else { unreachable!("event guard guarantees McpToolCallBegin"); }; @@ -119,7 +275,6 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { matches!(ev, EventMsg::McpToolCallEnd(_)) }) .await; - eprintln!("end_event: {end_event:?}"); let EventMsg::McpToolCallEnd(end) = end_event else { unreachable!("event guard guarantees McpToolCallEnd"); }; @@ -145,18 +300,72 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { .get("echo") .and_then(Value::as_str) .expect("echo payload present"); - assert_eq!(echo_value, "ping"); + assert_eq!(echo_value, "ECHOING: ping"); let env_value = map .get("env") .and_then(Value::as_str) .expect("env snapshot inserted"); assert_eq!(env_value, expected_env_value); - let task_complete_event = - wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - eprintln!("task_complete_event: {task_complete_event:?}"); + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; server.verify().await; + match http_server_child.try_wait() { + Ok(Some(_)) => {} + Ok(None) => { + let _ = http_server_child.kill().await; + } + Err(error) => { + eprintln!("failed to check streamable http server status: {error}"); + let _ = http_server_child.kill().await; + } + } + if let Err(error) = http_server_child.wait().await { + eprintln!("failed to await streamable http server shutdown: {error}"); + } + Ok(()) } + +async fn wait_for_streamable_http_server( + server_child: &mut Child, + address: &str, + timeout: Duration, +) -> anyhow::Result<()> { + let deadline = Instant::now() + timeout; + + loop { + if let Some(status) = server_child.try_wait()? { + return Err(anyhow::anyhow!( + "streamable HTTP server exited early with status {status}" + )); + } + + let remaining = deadline.saturating_duration_since(Instant::now()); + + if remaining.is_zero() { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: deadline reached" + )); + } + + match tokio::time::timeout(remaining, TcpStream::connect(address)).await { + Ok(Ok(_)) => return Ok(()), + Ok(Err(error)) => { + if Instant::now() >= deadline { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: {error}" + )); + } + } + Err(_) => { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: connect call timed out" + )); + } + } + + sleep(Duration::from_millis(50)).await; + } +} diff --git a/codex-rs/justfile b/codex-rs/justfile index 850737efd6..9ddc4a37aa 100644 --- a/codex-rs/justfile +++ b/codex-rs/justfile @@ -5,6 +5,7 @@ help: just -l # `codex` +alias c := codex codex *args: cargo run --bin codex -- "$@" @@ -27,6 +28,9 @@ fmt: fix *args: cargo clippy --fix --all-features --tests --allow-dirty "$@" +clippy: + cargo clippy --all-features --tests "$@" + install: rustup show active-toolchain cargo fetch diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index da9989e531..a377b94f07 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -16,6 +16,15 @@ rmcp = { version = "0.7.0", default-features = false, features = [ "schemars", "server", "transport-child-process", + "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", +] } +axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } +futures = { version = "0.3", default-features = false, features = ["std"] } +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", + "rustls-tls", ] } serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/codex-rs/rmcp-client/src/bin/test_stdio_server.rs b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs new file mode 100644 index 0000000000..2d380fa54e --- /dev/null +++ b/codex-rs/rmcp-client/src/bin/test_stdio_server.rs @@ -0,0 +1,142 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use rmcp::ErrorData as McpError; +use rmcp::ServiceExt; +use rmcp::handler::server::ServerHandler; +use rmcp::model::CallToolRequestParam; +use rmcp::model::CallToolResult; +use rmcp::model::JsonObject; +use rmcp::model::ListToolsResult; +use rmcp::model::PaginatedRequestParam; +use rmcp::model::ServerCapabilities; +use rmcp::model::ServerInfo; +use rmcp::model::Tool; +use serde::Deserialize; +use serde_json::json; +use tokio::task; + +#[derive(Clone)] +struct TestToolServer { + tools: Arc>, +} +pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) { + (tokio::io::stdin(), tokio::io::stdout()) +} +impl TestToolServer { + fn new() -> Self { + let tools = vec![Self::echo_tool()]; + Self { + tools: Arc::new(tools), + } + } + + fn echo_tool() -> Tool { + #[expect(clippy::expect_used)] + let schema: JsonObject = serde_json::from_value(json!({ + "type": "object", + "properties": { + "message": { "type": "string" }, + "env_var": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + })) + .expect("echo tool schema should deserialize"); + + Tool::new( + Cow::Borrowed("echo"), + Cow::Borrowed("Echo back the provided message and include environment data."), + Arc::new(schema), + ) + } +} + +#[derive(Deserialize)] +struct EchoArgs { + message: String, + #[allow(dead_code)] + env_var: Option, +} + +impl ServerHandler for TestToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_tool_list_changed() + .build(), + ..ServerInfo::default() + } + } + + fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + Send + '_ { + let tools = self.tools.clone(); + async move { + Ok(ListToolsResult { + tools: (*tools).clone(), + next_cursor: None, + }) + } + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + _context: rmcp::service::RequestContext, + ) -> Result { + match request.name.as_ref() { + "echo" => { + let args: EchoArgs = match request.arguments { + Some(arguments) => serde_json::from_value(serde_json::Value::Object( + arguments.into_iter().collect(), + )) + .map_err(|err| McpError::invalid_params(err.to_string(), None))?, + None => { + return Err(McpError::invalid_params( + "missing arguments for echo tool", + None, + )); + } + }; + + let env_snapshot: HashMap = std::env::vars().collect(); + let structured_content = json!({ + "echo": format!("ECHOING: {}", args.message), + "env": env_snapshot.get("MCP_TEST_VALUE"), + }); + + Ok(CallToolResult { + content: Vec::new(), + structured_content: Some(structured_content), + is_error: Some(false), + meta: None, + }) + } + other => Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )), + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + eprintln!("starting rmcp test server"); + // Run the server with STDIO transport. If the client disconnects we simply + // bubble up the error so the process exits. + let service = TestToolServer::new(); + let running = service.serve(stdio()).await?; + + // Wait for the client to finish interacting with the server. + running.waiting().await?; + // Drain background tasks to ensure clean shutdown. + task::yield_now().await; + Ok(()) +} diff --git a/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs new file mode 100644 index 0000000000..eedd1cb1e3 --- /dev/null +++ b/codex-rs/rmcp-client/src/bin/test_streamable_http_server.rs @@ -0,0 +1,167 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::io::ErrorKind; +use std::net::SocketAddr; +use std::sync::Arc; + +use axum::Router; +use rmcp::ErrorData as McpError; +use rmcp::handler::server::ServerHandler; +use rmcp::model::CallToolRequestParam; +use rmcp::model::CallToolResult; +use rmcp::model::JsonObject; +use rmcp::model::ListToolsResult; +use rmcp::model::PaginatedRequestParam; +use rmcp::model::ServerCapabilities; +use rmcp::model::ServerInfo; +use rmcp::model::Tool; +use rmcp::transport::StreamableHttpServerConfig; +use rmcp::transport::StreamableHttpService; +use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; +use serde::Deserialize; +use serde_json::json; +use tokio::task; + +#[derive(Clone)] +struct TestToolServer { + tools: Arc>, +} + +impl TestToolServer { + fn new() -> Self { + let tools = vec![Self::echo_tool()]; + Self { + tools: Arc::new(tools), + } + } + + fn echo_tool() -> Tool { + #[expect(clippy::expect_used)] + let schema: JsonObject = serde_json::from_value(json!({ + "type": "object", + "properties": { + "message": { "type": "string" }, + "env_var": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + })) + .expect("echo tool schema should deserialize"); + + Tool::new( + Cow::Borrowed("echo"), + Cow::Borrowed("Echo back the provided message and include environment data."), + Arc::new(schema), + ) + } +} + +#[derive(Deserialize)] +struct EchoArgs { + message: String, + #[allow(dead_code)] + env_var: Option, +} + +impl ServerHandler for TestToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_tool_list_changed() + .build(), + ..ServerInfo::default() + } + } + + fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + Send + '_ { + let tools = self.tools.clone(); + async move { + Ok(ListToolsResult { + tools: (*tools).clone(), + next_cursor: None, + }) + } + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + _context: rmcp::service::RequestContext, + ) -> Result { + match request.name.as_ref() { + "echo" => { + let args: EchoArgs = match request.arguments { + Some(arguments) => serde_json::from_value(serde_json::Value::Object( + arguments.into_iter().collect(), + )) + .map_err(|err| McpError::invalid_params(err.to_string(), None))?, + None => { + return Err(McpError::invalid_params( + "missing arguments for echo tool", + None, + )); + } + }; + + let env_snapshot: HashMap = std::env::vars().collect(); + let structured_content = json!({ + "echo": format!("ECHOING: {}", args.message), + "env": env_snapshot.get("MCP_TEST_VALUE"), + }); + + Ok(CallToolResult { + content: Vec::new(), + structured_content: Some(structured_content), + is_error: Some(false), + meta: None, + }) + } + other => Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )), + } + } +} + +fn parse_bind_addr() -> Result> { + let default_addr = "127.0.0.1:3920"; + let bind_addr = std::env::var("MCP_STREAMABLE_HTTP_BIND_ADDR") + .or_else(|_| std::env::var("BIND_ADDR")) + .unwrap_or_else(|_| default_addr.to_string()); + Ok(bind_addr.parse()?) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let bind_addr = parse_bind_addr()?; + let listener = match tokio::net::TcpListener::bind(&bind_addr).await { + Ok(listener) => listener, + Err(err) if err.kind() == ErrorKind::PermissionDenied => { + eprintln!( + "failed to bind to {bind_addr}: {err}. make sure the process has network access" + ); + return Ok(()); + } + Err(err) => return Err(err.into()), + }; + eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp"); + + let router = Router::new().nest_service( + "/mcp", + StreamableHttpService::new( + || Ok(TestToolServer::new()), + Arc::new(LocalSessionManager::default()), + StreamableHttpServerConfig::default(), + ), + ); + + axum::serve(listener, router).await?; + task::yield_now().await; + Ok(()) +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index c7ac1ecc9a..e127029472 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -7,6 +7,7 @@ use std::time::Duration; use anyhow::Result; use anyhow::anyhow; +use futures::FutureExt; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::InitializeRequestParams; @@ -19,7 +20,9 @@ use rmcp::model::PaginatedRequestParam; use rmcp::service::RoleClient; use rmcp::service::RunningService; use rmcp::service::{self}; +use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::child_process::TokioChildProcess; +use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; @@ -35,9 +38,14 @@ use crate::utils::convert_to_rmcp; use crate::utils::create_env_for_mcp_server; use crate::utils::run_with_timeout; +enum PendingTransport { + ChildProcess(TokioChildProcess), + StreamableHttp(StreamableHttpClientTransport), +} + enum ClientState { Connecting { - transport: Option, + transport: Option, }, Ready { service: Arc>, @@ -90,7 +98,22 @@ impl RmcpClient { Ok(Self { state: Mutex::new(ClientState::Connecting { - transport: Some(transport), + transport: Some(PendingTransport::ChildProcess(transport)), + }), + }) + } + + pub fn new_streamable_http_client(url: String, bearer_token: Option) -> Result { + let mut config = StreamableHttpClientTransportConfig::with_uri(url); + if let Some(token) = bearer_token { + config = config.auth_header(format!("Bearer {token}")); + } + + let transport = StreamableHttpClientTransport::from_config(config); + + Ok(Self { + state: Mutex::new(ClientState::Connecting { + transport: Some(PendingTransport::StreamableHttp(transport)), }), }) } @@ -116,7 +139,14 @@ impl RmcpClient { let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; let client_handler = LoggingClientHandler::new(client_info); - let service_future = service::serve_client(client_handler, transport); + let service_future = match transport { + PendingTransport::ChildProcess(transport) => { + service::serve_client(client_handler.clone(), transport).boxed() + } + PendingTransport::StreamableHttp(transport) => { + service::serve_client(client_handler, transport).boxed() + } + }; let service = match timeout { Some(duration) => time::timeout(duration, service_future) diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index d95663d9fc..d374aaa104 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -19,6 +19,7 @@ use crate::wrapping::word_wrap_line; use crate::wrapping::word_wrap_lines; use base64::Engine; use codex_core::config::Config; +use codex_core::config_types::McpServerTransportConfig; use codex_core::config_types::ReasoningSummaryFormat; use codex_core::plan_tool::PlanItemArg; use codex_core::plan_tool::StepStatus; @@ -848,10 +849,19 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Server: ".into(), server.clone().into()].into()); - if !cfg.command.is_empty() { - let cmd_display = format!("{} {}", cfg.command, cfg.args.join(" ")); - - lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); + match &cfg.transport { + McpServerTransportConfig::Stdio { command, args, .. } => { + let args_suffix = if args.is_empty() { + String::new() + } else { + format!(" {}", args.join(" ")) + }; + let cmd_display = format!("{command}{args_suffix}"); + lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); + } + McpServerTransportConfig::StreamableHttp { url, .. } => { + lines.push(vec![" • URL: ".into(), url.clone().into()].into()); + } } if names.is_empty() {