diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index e5b4e4c316..4019da5493 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -34,6 +34,7 @@ use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::TaskStartedEvent; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnContextItem; +use codex_rmcp_client::ElicitationResponse; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesOrdered; @@ -44,6 +45,7 @@ use mcp_types::ListResourcesRequestParams; use mcp_types::ListResourcesResult; use mcp_types::ReadResourceRequestParams; use mcp_types::ReadResourceResult; +use mcp_types::RequestId; use serde_json; use serde_json::Value; use tokio::sync::Mutex; @@ -938,6 +940,19 @@ impl Session { } } + pub async fn resolve_elicitation( + &self, + server_name: String, + id: RequestId, + response: ElicitationResponse, + ) -> anyhow::Result<()> { + self.services + .mcp_connection_manager + .read() + .await + .resolve_elicitation(server_name, id, response) + } + /// Records input items: always append to conversation history and /// persist these response items to rollout. pub(crate) async fn record_conversation_items( @@ -1406,6 +1421,13 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv ) .await; } + Op::ResolveElicitation { + server_name, + request_id, + decision, + } => { + handlers::resolve_elicitation(&sess, server_name, request_id, decision).await; + } Op::Shutdown => { if handlers::shutdown(&sess, sub.id.clone()).await { break; @@ -1444,6 +1466,9 @@ mod handlers { use codex_protocol::protocol::TurnAbortReason; use codex_protocol::user_input::UserInput; + use codex_rmcp_client::ElicitationAction; + use codex_rmcp_client::ElicitationResponse; + use mcp_types::RequestId; use std::sync::Arc; use tracing::info; use tracing::warn; @@ -1527,6 +1552,32 @@ mod handlers { *previous_context = Some(turn_context); } + pub async fn resolve_elicitation( + sess: &Arc, + server_name: String, + request_id: RequestId, + decision: codex_protocol::approvals::ElicitationAction, + ) { + let action = match decision { + codex_protocol::approvals::ElicitationAction::Accept => ElicitationAction::Accept, + codex_protocol::approvals::ElicitationAction::Decline => ElicitationAction::Decline, + codex_protocol::approvals::ElicitationAction::Cancel => ElicitationAction::Cancel, + }; + let response = ElicitationResponse { + action, + content: None, + }; + if let Err(err) = sess + .resolve_elicitation(server_name, request_id, response) + .await + { + warn!( + error = %err, + "failed to resolve elicitation request in session" + ); + } + } + pub async fn exec_approval(sess: &Arc, id: String, decision: ReviewDecision) { match decision { ReviewDecision::Abort => { diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index d8869e5e9f..e1b05cef48 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -11,6 +11,7 @@ use std::collections::HashSet; use std::env; use std::ffi::OsString; use std::sync::Arc; +use std::sync::Mutex; use std::time::Duration; use crate::mcp::auth::McpAuthStatusEntry; @@ -20,14 +21,17 @@ use anyhow::anyhow; use async_channel::Sender; use codex_async_utils::CancelErr; use codex_async_utils::OrCancelExt; +use codex_protocol::approvals::ElicitationRequestEvent; use codex_protocol::protocol::Event; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::McpStartupCompleteEvent; use codex_protocol::protocol::McpStartupFailure; use codex_protocol::protocol::McpStartupStatus; use codex_protocol::protocol::McpStartupUpdateEvent; +use codex_rmcp_client::ElicitationResponse; use codex_rmcp_client::OAuthCredentialsStoreMode; use codex_rmcp_client::RmcpClient; +use codex_rmcp_client::SendElicitation; use futures::future::BoxFuture; use futures::future::FutureExt; use futures::future::Shared; @@ -39,6 +43,7 @@ use mcp_types::ListResourcesRequestParams; use mcp_types::ListResourcesResult; use mcp_types::ReadResourceRequestParams; use mcp_types::ReadResourceResult; +use mcp_types::RequestId; use mcp_types::Resource; use mcp_types::ResourceTemplate; use mcp_types::Tool; @@ -46,6 +51,7 @@ use mcp_types::Tool; use serde_json::json; use sha1::Digest; use sha1::Sha1; +use tokio::sync::oneshot; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::warn; @@ -110,6 +116,58 @@ pub(crate) struct ToolInfo { pub(crate) tool: Tool, } +type ResponderMap = HashMap<(String, RequestId), oneshot::Sender>; + +#[derive(Clone, Default)] +struct ElicitationRequestManager { + requests: Arc>, +} + +impl ElicitationRequestManager { + fn resolve( + &self, + server_name: String, + id: RequestId, + response: ElicitationResponse, + ) -> Result<()> { + self.requests + .lock() + .map_err(|e| anyhow!("failed to lock elicitation requests: {e:?}"))? + .remove(&(server_name, id)) + .ok_or_else(|| anyhow!("elicitation request not found"))? + .send(response) + .map_err(|e| anyhow!("failed to send elicitation response: {e:?}")) + } + + fn make_sender(&self, server_name: String, tx_event: Sender) -> SendElicitation { + let elicitation_requests = self.requests.clone(); + Box::new(move |id, elicitation| { + let elicitation_requests = elicitation_requests.clone(); + let tx_event = tx_event.clone(); + let server_name = server_name.clone(); + async move { + let (tx, rx) = oneshot::channel(); + if let Ok(mut lock) = elicitation_requests.lock() { + lock.insert((server_name.clone(), id.clone()), tx); + } + let _ = tx_event + .send(Event { + id: "mcp_elicitation_request".to_string(), + msg: EventMsg::ElicitationRequest(ElicitationRequestEvent { + server_name, + id, + message: elicitation.message, + }), + }) + .await; + rx.await + .context("elicitation request channel closed unexpectedly") + } + .boxed() + }) + } +} + #[derive(Clone)] struct ManagedClient { client: Arc, @@ -129,19 +187,33 @@ impl AsyncManagedClient { config: McpServerConfig, store_mode: OAuthCredentialsStoreMode, cancel_token: CancellationToken, + tx_event: Sender, + elicitation_requests: ElicitationRequestManager, ) -> Self { let tool_filter = ToolFilter::from_config(&config); - let fut = start_server_task( - server_name, - config.transport, - store_mode, - config - .startup_timeout_sec - .unwrap_or(DEFAULT_STARTUP_TIMEOUT), - config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT), - tool_filter, - cancel_token, - ); + let fut = async move { + if let Err(error) = validate_mcp_server_name(&server_name) { + return Err(error.into()); + } + + let client = + Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); + match start_server_task( + server_name, + client, + config.startup_timeout_sec.or(Some(DEFAULT_STARTUP_TIMEOUT)), + config.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT), + tool_filter, + tx_event, + elicitation_requests, + ) + .or_cancel(&cancel_token) + .await + { + Ok(result) => result, + Err(CancelErr::Cancelled) => Err(StartupOutcomeError::Cancelled), + } + }; Self { client: fut.boxed().shared(), } @@ -156,6 +228,7 @@ impl AsyncManagedClient { #[derive(Default)] pub(crate) struct McpConnectionManager { clients: HashMap, + elicitation_requests: ElicitationRequestManager, } impl McpConnectionManager { @@ -172,6 +245,7 @@ impl McpConnectionManager { } let mut clients = HashMap::new(); let mut join_set = JoinSet::new(); + let elicitation_requests = ElicitationRequestManager::default(); for (server_name, cfg) in mcp_servers.into_iter().filter(|(_, cfg)| cfg.enabled) { let cancel_token = cancel_token.child_token(); let _ = emit_update( @@ -182,8 +256,14 @@ impl McpConnectionManager { }, ) .await; - let async_managed_client = - AsyncManagedClient::new(server_name.clone(), cfg, store_mode, cancel_token.clone()); + let async_managed_client = AsyncManagedClient::new( + server_name.clone(), + cfg, + store_mode, + cancel_token.clone(), + tx_event.clone(), + elicitation_requests.clone(), + ); clients.insert(server_name.clone(), async_managed_client.clone()); let tx_event = tx_event.clone(); let auth_entry = auth_entries.get(&server_name).cloned(); @@ -217,6 +297,7 @@ impl McpConnectionManager { }); } self.clients = clients; + self.elicitation_requests = elicitation_requests.clone(); tokio::spawn(async move { let outcomes = join_set.join_all().await; let mut summary = McpStartupCompleteEvent::default(); @@ -250,6 +331,15 @@ impl McpConnectionManager { .context("failed to get client") } + pub fn resolve_elicitation( + &self, + server_name: String, + id: RequestId, + response: ElicitationResponse, + ) -> Result<()> { + self.elicitation_requests.resolve(server_name, id, response) + } + /// Returns a single map that contains all tools. Each key is the /// fully-qualified name for the tool. pub async fn list_all_tools(&self) -> HashMap { @@ -580,43 +670,12 @@ impl From for StartupOutcomeError { async fn start_server_task( server_name: String, - transport: McpServerTransportConfig, - store_mode: OAuthCredentialsStoreMode, - startup_timeout: Duration, // TODO: cancel_token should handle this. - tool_timeout: Duration, - tool_filter: ToolFilter, - cancel_token: CancellationToken, -) -> Result { - if cancel_token.is_cancelled() { - return Err(StartupOutcomeError::Cancelled); - } - if let Err(error) = validate_mcp_server_name(&server_name) { - return Err(error.into()); - } - - match start_server_work( - server_name, - transport, - store_mode, - startup_timeout, - tool_timeout, - tool_filter, - ) - .or_cancel(&cancel_token) - .await - { - Ok(result) => result, - Err(CancelErr::Cancelled) => Err(StartupOutcomeError::Cancelled), - } -} - -async fn start_server_work( - server_name: String, - transport: McpServerTransportConfig, - store_mode: OAuthCredentialsStoreMode, - startup_timeout: Duration, + client: Arc, + startup_timeout: Option, // TODO: cancel_token should handle this. tool_timeout: Duration, tool_filter: ToolFilter, + tx_event: Sender, + elicitation_requests: ElicitationRequestManager, ) -> Result { let params = mcp_types::InitializeRequestParams { capabilities: ClientCapabilities { @@ -639,7 +698,33 @@ async fn start_server_work( protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; - let client_result = match transport { + let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); + + client + .initialize(params, startup_timeout, send_elicitation) + .await + .map_err(StartupOutcomeError::from)?; + + let tools = list_tools_for_client(&server_name, &client, startup_timeout) + .await + .map_err(StartupOutcomeError::from)?; + + let managed = ManagedClient { + client: Arc::clone(&client), + tools, + tool_timeout: Some(tool_timeout), + tool_filter, + }; + + Ok(managed) +} + +async fn make_rmcp_client( + server_name: &str, + transport: McpServerTransportConfig, + store_mode: OAuthCredentialsStoreMode, +) -> Result { + match transport { McpServerTransportConfig::Stdio { command, args, @@ -649,16 +734,9 @@ async fn start_server_work( } => { let command_os: OsString = command.into(); let args_os: Vec = args.into_iter().map(Into::into).collect(); - match RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd).await { - Ok(client) => { - let client = Arc::new(client); - client - .initialize(params.clone(), Some(startup_timeout)) - .await - .map(|_| client) - } - Err(err) => Err(err.into()), - } + RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd) + .await + .map_err(|err| StartupOutcomeError::from(anyhow!(err))) } McpServerTransportConfig::StreamableHttp { url, @@ -667,12 +745,12 @@ async fn start_server_work( bearer_token_env_var, } => { let resolved_bearer_token = - match resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()) { + match resolve_bearer_token(server_name, bearer_token_env_var.as_deref()) { Ok(token) => token, Err(error) => return Err(error.into()), }; - match RmcpClient::new_streamable_http_client( - &server_name, + RmcpClient::new_streamable_http_client( + server_name, &url, resolved_bearer_token, http_headers, @@ -680,49 +758,17 @@ async fn start_server_work( store_mode, ) .await - { - Ok(client) => { - let client = Arc::new(client); - client - .initialize(params.clone(), Some(startup_timeout)) - .await - .map(|_| client) - } - Err(err) => Err(err), - } - } - }; - - let client = match client_result { - Ok(client) => client, - Err(error) => { - return Err(error.into()); - } - }; - - let tools = match list_tools_for_client(&server_name, &client, startup_timeout).await { - Ok(tools) => tools, - Err(error) => { - return Err(error.into()); + .map_err(StartupOutcomeError::from) } - }; - - let managed = ManagedClient { - client: Arc::clone(&client), - tools, - tool_timeout: Some(tool_timeout), - tool_filter, - }; - - Ok(managed) + } } async fn list_tools_for_client( server_name: &str, client: &Arc, - timeout: Duration, + timeout: Option, ) -> Result> { - let resp = client.list_tools(None, Some(timeout)).await?; + let resp = client.list_tools(None, timeout).await?; Ok(resp .tools .into_iter() diff --git a/codex-rs/core/src/rollout/policy.rs b/codex-rs/core/src/rollout/policy.rs index 9e0e308362..4d5f709d25 100644 --- a/codex-rs/core/src/rollout/policy.rs +++ b/codex-rs/core/src/rollout/policy.rs @@ -64,6 +64,7 @@ pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool { | EventMsg::ExecCommandOutputDelta(_) | EventMsg::ExecCommandEnd(_) | EventMsg::ExecApprovalRequest(_) + | EventMsg::ElicitationRequest(_) | EventMsg::ApplyPatchApprovalRequest(_) | EventMsg::BackgroundEvent(_) | EventMsg::StreamError(_) diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs index 2d550fea46..e28b726cab 100644 --- a/codex-rs/exec/src/event_processor_with_human_output.rs +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -227,6 +227,19 @@ impl EventProcessor for EventProcessorWithHumanOutput { EventMsg::TaskStarted(_) => { // Ignore. } + EventMsg::ElicitationRequest(ev) => { + ts_msg!( + self, + "{} {}", + "elicitation request".style(self.magenta), + ev.server_name.style(self.dimmed) + ); + ts_msg!( + self, + "{}", + "auto-cancelling (not supported in exec mode)".style(self.dimmed) + ); + } EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { let last_message = last_agent_message.as_deref(); if let Some(output_file) = self.last_message_path.as_deref() { diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index a003b4ff21..eb013b8280 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -30,6 +30,7 @@ use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; use codex_core::protocol::SessionSource; +use codex_protocol::approvals::ElicitationAction; use codex_protocol::config_types::SandboxMode; use codex_protocol::user_input::UserInput; use event_processor_with_human_output::EventProcessorWithHumanOutput; @@ -401,6 +402,16 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any // exit with a non-zero status for automation-friendly signaling. let mut error_seen = false; while let Some(event) = rx.recv().await { + if let EventMsg::ElicitationRequest(ev) = &event.msg { + // Automatically cancel elicitation requests in exec mode. + conversation + .submit(Op::ResolveElicitation { + server_name: ev.server_name.clone(), + request_id: ev.id.clone(), + decision: ElicitationAction::Cancel, + }) + .await?; + } if matches!(event.msg, EventMsg::Error(_)) { error_seen = true; } diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 93dc7764dc..2eee3b853e 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -208,6 +208,10 @@ async fn run_codex_tool_session_inner( EventMsg::Warning(_) => { continue; } + EventMsg::ElicitationRequest(_) => { + // TODO: forward elicitation requests to the client? + continue; + } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { call_id, reason, diff --git a/codex-rs/protocol/src/approvals.rs b/codex-rs/protocol/src/approvals.rs index f7c5fc6049..17d6c08734 100644 --- a/codex-rs/protocol/src/approvals.rs +++ b/codex-rs/protocol/src/approvals.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use crate::parse_command::ParsedCommand; use crate::protocol::FileChange; +use mcp_types::RequestId; use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; @@ -53,6 +54,24 @@ pub struct ExecApprovalRequestEvent { pub parsed_cmd: Vec, } +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] +pub struct ElicitationRequestEvent { + pub server_name: String, + pub id: RequestId, + pub message: String, + // TODO: MCP servers can request we fill out a schema for the elicitation. We don't support + // this yet. + // pub requested_schema: ElicitRequestParamsRequestedSchema, +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, JsonSchema, TS)] +#[serde(rename_all = "lowercase")] +pub enum ElicitationAction { + Accept, + Decline, + Cancel, +} + #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS)] pub struct ApplyPatchApprovalRequestEvent { /// Responses API call id for the associated patch apply call, if available. diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index e3bc76199a..8b328f93ef 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -11,6 +11,7 @@ use std::str::FromStr; use std::time::Duration; use crate::ConversationId; +use crate::approvals::ElicitationRequestEvent; use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::custom_prompts::CustomPrompt; @@ -23,6 +24,7 @@ use crate::parse_command::ParsedCommand; use crate::plan_tool::UpdatePlanArgs; use crate::user_input::UserInput; use mcp_types::CallToolResult; +use mcp_types::RequestId; use mcp_types::Resource as McpResource; use mcp_types::ResourceTemplate as McpResourceTemplate; use mcp_types::Tool as McpTool; @@ -35,6 +37,7 @@ use strum_macros::Display; use ts_rs::TS; pub use crate::approvals::ApplyPatchApprovalRequestEvent; +pub use crate::approvals::ElicitationAction; pub use crate::approvals::ExecApprovalRequestEvent; pub use crate::approvals::SandboxCommandAssessment; pub use crate::approvals::SandboxRiskLevel; @@ -153,6 +156,16 @@ pub enum Op { decision: ReviewDecision, }, + /// Resolve an MCP elicitation request. + ResolveElicitation { + /// Name of the MCP server that issued the request. + server_name: String, + /// Request identifier from the MCP server. + request_id: RequestId, + /// User's decision for the request. + decision: ElicitationAction, + }, + /// Append an entry to the persistent cross-session message history. /// /// Note the entry is not guaranteed to be logged if the user has @@ -505,6 +518,8 @@ pub enum EventMsg { ExecApprovalRequest(ExecApprovalRequestEvent), + ElicitationRequest(ElicitationRequestEvent), + ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent), /// Notification advising the user that something they are using has been diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index 87ce86b464..ac617f3d29 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -17,4 +17,8 @@ pub use oauth::delete_oauth_tokens; pub(crate) use oauth::load_oauth_tokens; pub use oauth::save_oauth_tokens; pub use perform_oauth_login::perform_oauth_login; +pub use rmcp::model::ElicitationAction; +pub use rmcp_client::Elicitation; +pub use rmcp_client::ElicitationResponse; pub use rmcp_client::RmcpClient; +pub use rmcp_client::SendElicitation; diff --git a/codex-rs/rmcp-client/src/logging_client_handler.rs b/codex-rs/rmcp-client/src/logging_client_handler.rs index 85d237b0e9..0d2c3aaa97 100644 --- a/codex-rs/rmcp-client/src/logging_client_handler.rs +++ b/codex-rs/rmcp-client/src/logging_client_handler.rs @@ -1,13 +1,15 @@ +use std::sync::Arc; + use rmcp::ClientHandler; use rmcp::RoleClient; use rmcp::model::CancelledNotificationParam; use rmcp::model::ClientInfo; use rmcp::model::CreateElicitationRequestParam; use rmcp::model::CreateElicitationResult; -use rmcp::model::ElicitationAction; use rmcp::model::LoggingLevel; use rmcp::model::LoggingMessageNotificationParam; use rmcp::model::ProgressNotificationParam; +use rmcp::model::RequestId; use rmcp::model::ResourceUpdatedNotificationParam; use rmcp::service::NotificationContext; use rmcp::service::RequestContext; @@ -16,32 +18,36 @@ use tracing::error; use tracing::info; use tracing::warn; -#[derive(Debug, Clone)] +use crate::rmcp_client::SendElicitation; + +#[derive(Clone)] pub(crate) struct LoggingClientHandler { client_info: ClientInfo, + send_elicitation: Arc, } impl LoggingClientHandler { - pub(crate) fn new(client_info: ClientInfo) -> Self { - Self { client_info } + pub(crate) fn new(client_info: ClientInfo, send_elicitation: SendElicitation) -> Self { + Self { + client_info, + send_elicitation: Arc::new(send_elicitation), + } } } impl ClientHandler for LoggingClientHandler { - // TODO (CODEX-3571): support elicitations. async fn create_elicitation( &self, request: CreateElicitationRequestParam, - _context: RequestContext, + context: RequestContext, ) -> Result { - info!( - "MCP server requested elicitation ({}). Elicitations are not supported yet. Declining.", - request.message - ); - Ok(CreateElicitationResult { - action: ElicitationAction::Decline, - content: None, - }) + let id = match context.id { + RequestId::String(id) => mcp_types::RequestId::String(id.to_string()), + RequestId::Number(id) => mcp_types::RequestId::Integer(id), + }; + (self.send_elicitation)(id, request) + .await + .map_err(|err| rmcp::ErrorData::internal_error(err.to_string(), None)) } async fn on_cancelled( diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index d7d3477b00..fe9f48d04e 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -9,6 +9,7 @@ use std::time::Duration; use anyhow::Result; use anyhow::anyhow; use futures::FutureExt; +use futures::future::BoxFuture; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::InitializeRequestParams; @@ -21,8 +22,11 @@ use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; use mcp_types::ReadResourceRequestParams; use mcp_types::ReadResourceResult; +use mcp_types::RequestId; use reqwest::header::HeaderMap; use rmcp::model::CallToolRequestParam; +use rmcp::model::CreateElicitationRequestParam; +use rmcp::model::CreateElicitationResult; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; use rmcp::model::ReadResourceRequestParam; @@ -77,6 +81,14 @@ enum ClientState { }, } +pub type Elicitation = CreateElicitationRequestParam; +pub type ElicitationResponse = CreateElicitationResult; + +/// Interface for sending elicitation requests to the UI and awaiting a response. +pub type SendElicitation = Box< + dyn Fn(RequestId, Elicitation) -> BoxFuture<'static, Result> + Send + Sync, +>; + /// MCP client implemented on top of the official `rmcp` SDK. /// https://github.com/modelcontextprotocol/rust-sdk pub struct RmcpClient { @@ -200,9 +212,10 @@ impl RmcpClient { &self, params: InitializeRequestParams, timeout: Option, + send_elicitation: SendElicitation, ) -> Result { let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?; - let client_handler = LoggingClientHandler::new(rmcp_params); + let client_handler = LoggingClientHandler::new(rmcp_params, send_elicitation); let (transport, oauth_persistor) = { let mut guard = self.state.lock().await; diff --git a/codex-rs/rmcp-client/tests/resources.rs b/codex-rs/rmcp-client/tests/resources.rs index 2117f9b14c..fda21d14e2 100644 --- a/codex-rs/rmcp-client/tests/resources.rs +++ b/codex-rs/rmcp-client/tests/resources.rs @@ -2,8 +2,11 @@ use std::ffi::OsString; use std::path::PathBuf; use std::time::Duration; +use codex_rmcp_client::ElicitationAction; +use codex_rmcp_client::ElicitationResponse; use codex_rmcp_client::RmcpClient; use escargot::CargoBuild; +use futures::FutureExt as _; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::InitializeRequestParams; @@ -55,7 +58,19 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> { .await?; client - .initialize(init_params(), Some(Duration::from_secs(5))) + .initialize( + init_params(), + Some(Duration::from_secs(5)), + Box::new(|_, _| { + async { + Ok(ElicitationResponse { + action: ElicitationAction::Accept, + content: Some(json!({})), + }) + } + .boxed() + }), + ) .await?; let list = client diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index 7c86dd3b6e..6d3feca229 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -44,6 +44,8 @@ use crossterm::event::KeyEvent; use crossterm::event::KeyEventKind; use ratatui::style::Stylize; use ratatui::text::Line; +use ratatui::widgets::Paragraph; +use ratatui::widgets::Wrap; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; @@ -821,6 +823,23 @@ impl App { "E X E C".to_string(), )); } + ApprovalRequest::McpElicitation { + server_name, + message, + .. + } => { + let _ = tui.enter_alt_screen(); + let paragraph = Paragraph::new(vec![ + Line::from(vec!["Server: ".into(), server_name.bold()]), + Line::from(""), + Line::from(message), + ]) + .wrap(Wrap { trim: false }); + self.overlay = Some(Overlay::new_static_with_renderables( + vec![Box::new(paragraph)], + "E L I C I T A T I O N".to_string(), + )); + } }, } Ok(true) diff --git a/codex-rs/tui/src/bottom_pane/approval_overlay.rs b/codex-rs/tui/src/bottom_pane/approval_overlay.rs index ef709f0051..c6006a9ae7 100644 --- a/codex-rs/tui/src/bottom_pane/approval_overlay.rs +++ b/codex-rs/tui/src/bottom_pane/approval_overlay.rs @@ -16,6 +16,7 @@ use crate::key_hint::KeyBinding; use crate::render::highlight::highlight_bash_to_lines; use crate::render::renderable::ColumnRenderable; use crate::render::renderable::Renderable; +use codex_core::protocol::ElicitationAction; use codex_core::protocol::FileChange; use codex_core::protocol::Op; use codex_core::protocol::ReviewDecision; @@ -25,6 +26,7 @@ use crossterm::event::KeyCode; use crossterm::event::KeyEvent; use crossterm::event::KeyEventKind; use crossterm::event::KeyModifiers; +use mcp_types::RequestId; use ratatui::buffer::Buffer; use ratatui::layout::Rect; use ratatui::style::Stylize; @@ -48,6 +50,11 @@ pub(crate) enum ApprovalRequest { cwd: PathBuf, changes: HashMap, }, + McpElicitation { + server_name: String, + request_id: RequestId, + message: String, + }, } /// Modal overlay asking the user to approve or deny one or more requests. @@ -105,6 +112,10 @@ impl ApprovalOverlay { patch_options(), "Would you like to make the following edits?".to_string(), ), + ApprovalVariant::McpElicitation { server_name, .. } => ( + elicitation_options(), + format!("{server_name} needs your approval."), + ), }; let header = Box::new(ColumnRenderable::with([ @@ -149,13 +160,23 @@ impl ApprovalOverlay { return; }; if let Some(variant) = self.current_variant.as_ref() { - match (&variant, option.decision) { - (ApprovalVariant::Exec { id, command }, decision) => { - self.handle_exec_decision(id, command, decision); + match (&variant, &option.decision) { + (ApprovalVariant::Exec { id, command }, ApprovalDecision::Review(decision)) => { + self.handle_exec_decision(id, command, *decision); + } + (ApprovalVariant::ApplyPatch { id, .. }, ApprovalDecision::Review(decision)) => { + self.handle_patch_decision(id, *decision); } - (ApprovalVariant::ApplyPatch { id, .. }, decision) => { - self.handle_patch_decision(id, decision); + ( + ApprovalVariant::McpElicitation { + server_name, + request_id, + }, + ApprovalDecision::McpElicitation(decision), + ) => { + self.handle_elicitation_decision(server_name, request_id, *decision); } + _ => {} } } @@ -179,6 +200,20 @@ impl ApprovalOverlay { })); } + fn handle_elicitation_decision( + &self, + server_name: &str, + request_id: &RequestId, + decision: ElicitationAction, + ) { + self.app_event_tx + .send(AppEvent::CodexOp(Op::ResolveElicitation { + server_name: server_name.to_string(), + request_id: request_id.clone(), + decision, + })); + } + fn advance_queue(&mut self) { if let Some(next) = self.queue.pop() { self.set_current(next); @@ -244,6 +279,16 @@ impl BottomPaneView for ApprovalOverlay { ApprovalVariant::ApplyPatch { id, .. } => { self.handle_patch_decision(id, ReviewDecision::Abort); } + ApprovalVariant::McpElicitation { + server_name, + request_id, + } => { + self.handle_elicitation_decision( + server_name, + request_id, + ElicitationAction::Cancel, + ); + } } } self.queue.clear(); @@ -336,6 +381,25 @@ impl From for ApprovalRequestState { header: Box::new(ColumnRenderable::with(header)), } } + ApprovalRequest::McpElicitation { + server_name, + request_id, + message, + } => { + let header = Paragraph::new(vec![ + Line::from(vec!["Server: ".into(), server_name.clone().bold()]), + Line::from(""), + Line::from(message), + ]) + .wrap(Wrap { trim: false }); + Self { + variant: ApprovalVariant::McpElicitation { + server_name, + request_id, + }, + header: Box::new(header), + } + } } } } @@ -364,14 +428,29 @@ fn render_risk_lines(risk: &SandboxCommandAssessment) -> Vec> { #[derive(Clone)] enum ApprovalVariant { - Exec { id: String, command: Vec }, - ApplyPatch { id: String }, + Exec { + id: String, + command: Vec, + }, + ApplyPatch { + id: String, + }, + McpElicitation { + server_name: String, + request_id: RequestId, + }, +} + +#[derive(Clone)] +enum ApprovalDecision { + Review(ReviewDecision), + McpElicitation(ElicitationAction), } #[derive(Clone)] struct ApprovalOption { label: String, - decision: ReviewDecision, + decision: ApprovalDecision, display_shortcut: Option, additional_shortcuts: Vec, } @@ -388,19 +467,19 @@ fn exec_options() -> Vec { vec![ ApprovalOption { label: "Yes, proceed".to_string(), - decision: ReviewDecision::Approved, + decision: ApprovalDecision::Review(ReviewDecision::Approved), display_shortcut: None, additional_shortcuts: vec![key_hint::plain(KeyCode::Char('y'))], }, ApprovalOption { label: "Yes, and don't ask again for this command".to_string(), - decision: ReviewDecision::ApprovedForSession, + decision: ApprovalDecision::Review(ReviewDecision::ApprovedForSession), display_shortcut: None, additional_shortcuts: vec![key_hint::plain(KeyCode::Char('a'))], }, ApprovalOption { label: "No, and tell Codex what to do differently".to_string(), - decision: ReviewDecision::Abort, + decision: ApprovalDecision::Review(ReviewDecision::Abort), display_shortcut: Some(key_hint::plain(KeyCode::Esc)), additional_shortcuts: vec![key_hint::plain(KeyCode::Char('n'))], }, @@ -411,19 +490,42 @@ fn patch_options() -> Vec { vec![ ApprovalOption { label: "Yes, proceed".to_string(), - decision: ReviewDecision::Approved, + decision: ApprovalDecision::Review(ReviewDecision::Approved), display_shortcut: None, additional_shortcuts: vec![key_hint::plain(KeyCode::Char('y'))], }, ApprovalOption { label: "No, and tell Codex what to do differently".to_string(), - decision: ReviewDecision::Abort, + decision: ApprovalDecision::Review(ReviewDecision::Abort), display_shortcut: Some(key_hint::plain(KeyCode::Esc)), additional_shortcuts: vec![key_hint::plain(KeyCode::Char('n'))], }, ] } +fn elicitation_options() -> Vec { + vec![ + ApprovalOption { + label: "Yes, provide the requested info".to_string(), + decision: ApprovalDecision::McpElicitation(ElicitationAction::Accept), + display_shortcut: None, + additional_shortcuts: vec![key_hint::plain(KeyCode::Char('y'))], + }, + ApprovalOption { + label: "No, but continue without it".to_string(), + decision: ApprovalDecision::McpElicitation(ElicitationAction::Decline), + display_shortcut: None, + additional_shortcuts: vec![key_hint::plain(KeyCode::Char('n'))], + }, + ApprovalOption { + label: "Cancel this request".to_string(), + decision: ApprovalDecision::McpElicitation(ElicitationAction::Cancel), + display_shortcut: Some(key_hint::plain(KeyCode::Esc)), + additional_shortcuts: vec![key_hint::plain(KeyCode::Char('c'))], + }, + ] +} + #[cfg(test)] mod tests { use super::*; diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index 9429ce143e..a9560e7242 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -53,6 +53,7 @@ use codex_core::protocol::WarningEvent; use codex_core::protocol::WebSearchBeginEvent; use codex_core::protocol::WebSearchEndEvent; use codex_protocol::ConversationId; +use codex_protocol::approvals::ElicitationRequestEvent; use codex_protocol::parse_command::ParsedCommand; use codex_protocol::user_input::UserInput; use crossterm::event::KeyCode; @@ -708,6 +709,14 @@ impl ChatWidget { ); } + fn on_elicitation_request(&mut self, ev: ElicitationRequestEvent) { + let ev2 = ev.clone(); + self.defer_or_handle( + |q| q.push_elicitation(ev), + |s| s.handle_elicitation_request_now(ev2), + ); + } + fn on_exec_command_begin(&mut self, ev: ExecCommandBeginEvent) { self.flush_answer_stream_with_separator(); let ev2 = ev.clone(); @@ -1013,6 +1022,22 @@ impl ChatWidget { }); } + pub(crate) fn handle_elicitation_request_now(&mut self, ev: ElicitationRequestEvent) { + self.flush_answer_stream_with_separator(); + + self.notify(Notification::ElicitationRequested { + server_name: ev.server_name.clone(), + }); + + let request = ApprovalRequest::McpElicitation { + server_name: ev.server_name, + request_id: ev.id, + message: ev.message, + }; + self.bottom_pane.push_approval_request(request); + self.request_redraw(); + } + pub(crate) fn handle_exec_begin_now(&mut self, ev: ExecCommandBeginEvent) { // Ensure the status indicator is visible while the command runs. self.running_commands.insert( @@ -1649,6 +1674,9 @@ impl ChatWidget { EventMsg::ApplyPatchApprovalRequest(ev) => { self.on_apply_patch_approval_request(id.unwrap_or_default(), ev) } + EventMsg::ElicitationRequest(ev) => { + self.on_elicitation_request(ev); + } EventMsg::ExecCommandBegin(ev) => self.on_exec_command_begin(ev), EventMsg::ExecCommandOutputDelta(delta) => self.on_exec_command_output_delta(delta), EventMsg::PatchApplyBegin(ev) => self.on_patch_apply_begin(ev), @@ -2938,6 +2966,7 @@ enum Notification { AgentTurnComplete { response: String }, ExecApprovalRequested { command: String }, EditApprovalRequested { cwd: PathBuf, changes: Vec }, + ElicitationRequested { server_name: String }, } impl Notification { @@ -2961,6 +2990,9 @@ impl Notification { } ) } + Notification::ElicitationRequested { server_name } => { + format!("Approval requested by {server_name}") + } } } @@ -2968,7 +3000,8 @@ impl Notification { match self { Notification::AgentTurnComplete { .. } => "agent-turn-complete", Notification::ExecApprovalRequested { .. } - | Notification::EditApprovalRequested { .. } => "approval-requested", + | Notification::EditApprovalRequested { .. } + | Notification::ElicitationRequested { .. } => "approval-requested", } } diff --git a/codex-rs/tui/src/chatwidget/interrupts.rs b/codex-rs/tui/src/chatwidget/interrupts.rs index 531de3e646..dc1e683ea5 100644 --- a/codex-rs/tui/src/chatwidget/interrupts.rs +++ b/codex-rs/tui/src/chatwidget/interrupts.rs @@ -7,6 +7,7 @@ use codex_core::protocol::ExecCommandEndEvent; use codex_core::protocol::McpToolCallBeginEvent; use codex_core::protocol::McpToolCallEndEvent; use codex_core::protocol::PatchApplyEndEvent; +use codex_protocol::approvals::ElicitationRequestEvent; use super::ChatWidget; @@ -14,6 +15,7 @@ use super::ChatWidget; pub(crate) enum QueuedInterrupt { ExecApproval(String, ExecApprovalRequestEvent), ApplyPatchApproval(String, ApplyPatchApprovalRequestEvent), + Elicitation(ElicitationRequestEvent), ExecBegin(ExecCommandBeginEvent), ExecEnd(ExecCommandEndEvent), McpBegin(McpToolCallBeginEvent), @@ -51,6 +53,10 @@ impl InterruptManager { .push_back(QueuedInterrupt::ApplyPatchApproval(id, ev)); } + pub(crate) fn push_elicitation(&mut self, ev: ElicitationRequestEvent) { + self.queue.push_back(QueuedInterrupt::Elicitation(ev)); + } + pub(crate) fn push_exec_begin(&mut self, ev: ExecCommandBeginEvent) { self.queue.push_back(QueuedInterrupt::ExecBegin(ev)); } @@ -78,6 +84,7 @@ impl InterruptManager { QueuedInterrupt::ApplyPatchApproval(id, ev) => { chat.handle_apply_patch_approval_now(id, ev) } + QueuedInterrupt::Elicitation(ev) => chat.handle_elicitation_request_now(ev), QueuedInterrupt::ExecBegin(ev) => chat.handle_exec_begin_now(ev), QueuedInterrupt::ExecEnd(ev) => chat.handle_exec_end_now(ev), QueuedInterrupt::McpBegin(ev) => chat.handle_mcp_begin_now(ev),