diff --git a/codex-rs/core/src/tools/handlers/agent_jobs.rs b/codex-rs/core/src/tools/handlers/agent_jobs.rs index 93380d5e4b96..4c3cd34c2eae 100644 --- a/codex-rs/core/src/tools/handlers/agent_jobs.rs +++ b/codex-rs/core/src/tools/handlers/agent_jobs.rs @@ -6,20 +6,14 @@ use crate::config::Config; use crate::function_tool::FunctionCallError; use crate::session::session::Session; use crate::session::turn_context::TurnContext; -use crate::tools::context::FunctionToolOutput; -use crate::tools::context::ToolInvocation; -use crate::tools::context::ToolPayload; use crate::tools::handlers::multi_agents::build_agent_spawn_config; use crate::tools::handlers::parse_arguments; -use crate::tools::registry::ToolHandler; -use crate::tools::registry::ToolKind; use codex_protocol::ThreadId; use codex_protocol::error::CodexErr; use codex_protocol::protocol::AgentStatus; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::user_input::UserInput; -use codex_tools::ToolName; use codex_utils_absolute_path::AbsolutePathBuf; use futures::StreamExt; use futures::stream::FuturesUnordered; @@ -36,8 +30,11 @@ use tokio::time::Instant; use tokio::time::timeout; use uuid::Uuid; -pub struct SpawnAgentsOnCsvHandler; -pub struct ReportAgentJobResultHandler; +mod report_agent_job_result; +mod spawn_agents_on_csv; + +pub use report_agent_job_result::ReportAgentJobResultHandler; +pub use spawn_agents_on_csv::SpawnAgentsOnCsvHandler; const DEFAULT_AGENT_JOB_CONCURRENCY: usize = 16; const MAX_AGENT_JOB_CONCURRENCY: usize = 64; @@ -101,364 +98,6 @@ struct ActiveJobItem { status_rx: Option>, } -impl ToolHandler for SpawnAgentsOnCsvHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("spawn_agents_on_csv") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "agent jobs handler received unsupported payload".to_string(), - )); - } - }; - - spawn_agents_on_csv::handle(session, turn, arguments).await - } -} - -impl ToolHandler for ReportAgentJobResultHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("report_agent_job_result") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, payload, .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "report_agent_job_result handler received unsupported payload".to_string(), - )); - } - }; - - report_agent_job_result::handle(session, arguments).await - } -} - -mod spawn_agents_on_csv { - use super::*; - - /// Create a new agent job from a CSV and run it to completion. - /// - /// Each CSV row becomes a job item. The instruction string is a template where `{column}` - /// placeholders are filled with values from that row. Results are reported by workers via - /// `report_agent_job_result`, then exported to CSV on completion. - pub async fn handle( - session: Arc, - turn: Arc, - arguments: String, - ) -> Result { - let args: SpawnAgentsOnCsvArgs = parse_arguments(arguments.as_str())?; - if args.instruction.trim().is_empty() { - return Err(FunctionCallError::RespondToModel( - "instruction must be non-empty".to_string(), - )); - } - - let db = required_state_db(&session)?; - let input_path = turn.resolve_path(Some(args.csv_path)); - let input_path_display = input_path.display().to_string(); - let csv_content = tokio::fs::read_to_string(&input_path) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to read csv input {input_path_display}: {err}" - )) - })?; - let (headers, rows) = parse_csv(csv_content.as_str()).map_err(|err| { - FunctionCallError::RespondToModel(format!("failed to parse csv input: {err}")) - })?; - if headers.is_empty() { - return Err(FunctionCallError::RespondToModel( - "csv input must include a header row".to_string(), - )); - } - ensure_unique_headers(headers.as_slice())?; - - let id_column_index = args.id_column.as_ref().map_or(Ok(None), |column_name| { - headers - .iter() - .position(|header| header == column_name) - .map(Some) - .ok_or_else(|| { - FunctionCallError::RespondToModel(format!( - "id_column {column_name} was not found in csv headers" - )) - }) - })?; - - let mut items = Vec::with_capacity(rows.len()); - let mut seen_ids = HashSet::new(); - for (idx, row) in rows.into_iter().enumerate() { - if row.len() != headers.len() { - let row_index = idx + 2; - let row_len = row.len(); - let header_len = headers.len(); - return Err(FunctionCallError::RespondToModel(format!( - "csv row {row_index} has {row_len} fields but header has {header_len}" - ))); - } - - let source_id = id_column_index - .and_then(|index| row.get(index).cloned()) - .filter(|value| !value.trim().is_empty()); - let row_index = idx + 1; - let base_item_id = source_id - .clone() - .unwrap_or_else(|| format!("row-{row_index}")); - let mut item_id = base_item_id.clone(); - let mut suffix = 2usize; - while !seen_ids.insert(item_id.clone()) { - item_id = format!("{base_item_id}-{suffix}"); - suffix = suffix.saturating_add(1); - } - - let row_object = headers - .iter() - .zip(row.iter()) - .map(|(header, value)| (header.clone(), Value::String(value.clone()))) - .collect::>(); - items.push(codex_state::AgentJobItemCreateParams { - item_id, - row_index: idx as i64, - source_id, - row_json: Value::Object(row_object), - }); - } - - let job_id = Uuid::new_v4().to_string(); - let output_csv_path = args.output_csv_path.map_or_else( - || default_output_csv_path(&input_path, job_id.as_str()), - |path| turn.resolve_path(Some(path)), - ); - let job_suffix = &job_id[..8]; - let job_name = format!("agent-job-{job_suffix}"); - let max_runtime_seconds = normalize_max_runtime_seconds( - args.max_runtime_seconds - .or(turn.config.agent_job_max_runtime_seconds), - )?; - let _job = db - .create_agent_job( - &codex_state::AgentJobCreateParams { - id: job_id.clone(), - name: job_name, - instruction: args.instruction, - auto_export: true, - max_runtime_seconds, - output_schema_json: args.output_schema, - input_headers: headers, - input_csv_path: input_path.display().to_string(), - output_csv_path: output_csv_path.display().to_string(), - }, - items.as_slice(), - ) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!("failed to create agent job: {err}")) - })?; - - let requested_concurrency = args.max_concurrency.or(args.max_workers); - let options = match build_runner_options(&session, &turn, requested_concurrency).await { - Ok(options) => options, - Err(err) => { - let error_message = err.to_string(); - let _ = db - .mark_agent_job_failed(job_id.as_str(), error_message.as_str()) - .await; - return Err(err); - } - }; - db.mark_agent_job_running(job_id.as_str()) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to transition agent job {job_id} to running: {err}" - )) - })?; - if let Err(err) = run_agent_job_loop( - session.clone(), - turn.clone(), - db.clone(), - job_id.clone(), - options, - ) - .await - { - let error_message = format!("job runner failed: {err}"); - let _ = db - .mark_agent_job_failed(job_id.as_str(), error_message.as_str()) - .await; - return Err(FunctionCallError::RespondToModel(format!( - "agent job {job_id} failed: {err}" - ))); - } - - let job = db - .get_agent_job(job_id.as_str()) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to load agent job {job_id}: {err}" - )) - })? - .ok_or_else(|| { - FunctionCallError::RespondToModel(format!("agent job {job_id} not found")) - })?; - let output_path = PathBuf::from(job.output_csv_path.clone()); - if !tokio::fs::try_exists(&output_path).await.unwrap_or(false) { - export_job_csv_snapshot(db.clone(), &job) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to export output csv {job_id}: {err}" - )) - })?; - } - let progress = db - .get_agent_job_progress(job_id.as_str()) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to load agent job progress {job_id}: {err}" - )) - })?; - let mut job_error = job.last_error.clone().filter(|err| !err.trim().is_empty()); - let failed_item_errors = if progress.failed_items > 0 { - let items = db - .list_agent_job_items( - job_id.as_str(), - Some(codex_state::AgentJobItemStatus::Failed), - Some(5), - ) - .await - .unwrap_or_default(); - let summaries: Vec<_> = items - .into_iter() - .filter_map(|item| { - let last_error = item.last_error.unwrap_or_default(); - if last_error.trim().is_empty() { - return None; - } - Some(AgentJobFailureSummary { - item_id: item.item_id, - source_id: item.source_id, - last_error, - }) - }) - .collect(); - if summaries.is_empty() { - if job_error.is_none() { - job_error = Some( - "agent job has failed items but no error details were recorded".to_string(), - ); - } - None - } else { - Some(summaries) - } - } else { - None - }; - let content = serde_json::to_string(&SpawnAgentsOnCsvResult { - job_id, - status: job.status.as_str().to_string(), - output_csv_path: job.output_csv_path, - total_items: progress.total_items, - completed_items: progress.completed_items, - failed_items: progress.failed_items, - job_error, - failed_item_errors, - }) - .map_err(|err| { - FunctionCallError::Fatal(format!( - "failed to serialize spawn_agents_on_csv result: {err}" - )) - })?; - Ok(FunctionToolOutput::from_text(content, Some(true))) - } -} - -mod report_agent_job_result { - use super::*; - - pub async fn handle( - session: Arc, - arguments: String, - ) -> Result { - let args: ReportAgentJobResultArgs = parse_arguments(arguments.as_str())?; - if !args.result.is_object() { - return Err(FunctionCallError::RespondToModel( - "result must be a JSON object".to_string(), - )); - } - let db = required_state_db(&session)?; - let reporting_thread_id = session.conversation_id.to_string(); - let accepted = db - .report_agent_job_item_result( - args.job_id.as_str(), - args.item_id.as_str(), - reporting_thread_id.as_str(), - &args.result, - ) - .await - .map_err(|err| { - let job_id = args.job_id.as_str(); - let item_id = args.item_id.as_str(); - FunctionCallError::RespondToModel(format!( - "failed to record agent job result for {job_id} / {item_id}: {err}" - )) - })?; - if accepted && args.stop.unwrap_or(false) { - let message = "cancelled by worker request"; - let _ = db - .mark_agent_job_cancelled(args.job_id.as_str(), message) - .await; - } - let content = - serde_json::to_string(&ReportAgentJobResultToolResult { accepted }).map_err(|err| { - FunctionCallError::Fatal(format!( - "failed to serialize report_agent_job_result result: {err}" - )) - })?; - Ok(FunctionToolOutput::from_text(content, Some(true))) - } -} - fn required_state_db( session: &Arc, ) -> Result, FunctionCallError> { diff --git a/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs b/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs new file mode 100644 index 000000000000..90cde7d44638 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/agent_jobs/report_agent_job_result.rs @@ -0,0 +1,86 @@ +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_tools::ToolName; + +use super::*; + +pub struct ReportAgentJobResultHandler; + +impl ToolHandler for ReportAgentJobResultHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("report_agent_job_result") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, payload, .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "report_agent_job_result handler received unsupported payload".to_string(), + )); + } + }; + + handle(session, arguments).await + } +} + +pub async fn handle( + session: Arc, + arguments: String, +) -> Result { + let args: ReportAgentJobResultArgs = parse_arguments(arguments.as_str())?; + if !args.result.is_object() { + return Err(FunctionCallError::RespondToModel( + "result must be a JSON object".to_string(), + )); + } + let db = required_state_db(&session)?; + let reporting_thread_id = session.conversation_id.to_string(); + let accepted = db + .report_agent_job_item_result( + args.job_id.as_str(), + args.item_id.as_str(), + reporting_thread_id.as_str(), + &args.result, + ) + .await + .map_err(|err| { + let job_id = args.job_id.as_str(); + let item_id = args.item_id.as_str(); + FunctionCallError::RespondToModel(format!( + "failed to record agent job result for {job_id} / {item_id}: {err}" + )) + })?; + if accepted && args.stop.unwrap_or(false) { + let message = "cancelled by worker request"; + let _ = db + .mark_agent_job_cancelled(args.job_id.as_str(), message) + .await; + } + let content = + serde_json::to_string(&ReportAgentJobResultToolResult { accepted }).map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize report_agent_job_result result: {err}" + )) + })?; + Ok(FunctionToolOutput::from_text(content, Some(true))) +} diff --git a/codex-rs/core/src/tools/handlers/agent_jobs/spawn_agents_on_csv.rs b/codex-rs/core/src/tools/handlers/agent_jobs/spawn_agents_on_csv.rs new file mode 100644 index 000000000000..911f1a5eef5b --- /dev/null +++ b/codex-rs/core/src/tools/handlers/agent_jobs/spawn_agents_on_csv.rs @@ -0,0 +1,284 @@ +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_tools::ToolName; + +use super::*; + +pub struct SpawnAgentsOnCsvHandler; + +impl ToolHandler for SpawnAgentsOnCsvHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("spawn_agents_on_csv") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "agent jobs handler received unsupported payload".to_string(), + )); + } + }; + + handle(session, turn, arguments).await + } +} + +/// Create a new agent job from a CSV and run it to completion. +/// +/// Each CSV row becomes a job item. The instruction string is a template where `{column}` +/// placeholders are filled with values from that row. Results are reported by workers via +/// `report_agent_job_result`, then exported to CSV on completion. +pub async fn handle( + session: Arc, + turn: Arc, + arguments: String, +) -> Result { + let args: SpawnAgentsOnCsvArgs = parse_arguments(arguments.as_str())?; + if args.instruction.trim().is_empty() { + return Err(FunctionCallError::RespondToModel( + "instruction must be non-empty".to_string(), + )); + } + + let db = required_state_db(&session)?; + let input_path = turn.resolve_path(Some(args.csv_path)); + let input_path_display = input_path.display().to_string(); + let csv_content = tokio::fs::read_to_string(&input_path) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to read csv input {input_path_display}: {err}" + )) + })?; + let (headers, rows) = parse_csv(csv_content.as_str()).map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to parse csv input: {err}")) + })?; + if headers.is_empty() { + return Err(FunctionCallError::RespondToModel( + "csv input must include a header row".to_string(), + )); + } + ensure_unique_headers(headers.as_slice())?; + + let id_column_index = args.id_column.as_ref().map_or(Ok(None), |column_name| { + headers + .iter() + .position(|header| header == column_name) + .map(Some) + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "id_column {column_name} was not found in csv headers" + )) + }) + })?; + + let mut items = Vec::with_capacity(rows.len()); + let mut seen_ids = HashSet::new(); + for (idx, row) in rows.into_iter().enumerate() { + if row.len() != headers.len() { + let row_index = idx + 2; + let row_len = row.len(); + let header_len = headers.len(); + return Err(FunctionCallError::RespondToModel(format!( + "csv row {row_index} has {row_len} fields but header has {header_len}" + ))); + } + + let source_id = id_column_index + .and_then(|index| row.get(index).cloned()) + .filter(|value| !value.trim().is_empty()); + let row_index = idx + 1; + let base_item_id = source_id + .clone() + .unwrap_or_else(|| format!("row-{row_index}")); + let mut item_id = base_item_id.clone(); + let mut suffix = 2usize; + while !seen_ids.insert(item_id.clone()) { + item_id = format!("{base_item_id}-{suffix}"); + suffix = suffix.saturating_add(1); + } + + let row_object = headers + .iter() + .zip(row.iter()) + .map(|(header, value)| (header.clone(), Value::String(value.clone()))) + .collect::>(); + items.push(codex_state::AgentJobItemCreateParams { + item_id, + row_index: idx as i64, + source_id, + row_json: Value::Object(row_object), + }); + } + + let job_id = Uuid::new_v4().to_string(); + let output_csv_path = args.output_csv_path.map_or_else( + || default_output_csv_path(&input_path, job_id.as_str()), + |path| turn.resolve_path(Some(path)), + ); + let job_suffix = &job_id[..8]; + let job_name = format!("agent-job-{job_suffix}"); + let max_runtime_seconds = normalize_max_runtime_seconds( + args.max_runtime_seconds + .or(turn.config.agent_job_max_runtime_seconds), + )?; + let _job = db + .create_agent_job( + &codex_state::AgentJobCreateParams { + id: job_id.clone(), + name: job_name, + instruction: args.instruction, + auto_export: true, + max_runtime_seconds, + output_schema_json: args.output_schema, + input_headers: headers, + input_csv_path: input_path.display().to_string(), + output_csv_path: output_csv_path.display().to_string(), + }, + items.as_slice(), + ) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to create agent job: {err}")) + })?; + + let requested_concurrency = args.max_concurrency.or(args.max_workers); + let options = match build_runner_options(&session, &turn, requested_concurrency).await { + Ok(options) => options, + Err(err) => { + let error_message = err.to_string(); + let _ = db + .mark_agent_job_failed(job_id.as_str(), error_message.as_str()) + .await; + return Err(err); + } + }; + db.mark_agent_job_running(job_id.as_str()) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to transition agent job {job_id} to running: {err}" + )) + })?; + if let Err(err) = run_agent_job_loop( + session.clone(), + turn.clone(), + db.clone(), + job_id.clone(), + options, + ) + .await + { + let error_message = format!("job runner failed: {err}"); + let _ = db + .mark_agent_job_failed(job_id.as_str(), error_message.as_str()) + .await; + return Err(FunctionCallError::RespondToModel(format!( + "agent job {job_id} failed: {err}" + ))); + } + + let job = db + .get_agent_job(job_id.as_str()) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to load agent job {job_id}: {err}")) + })? + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!("agent job {job_id} not found")) + })?; + let output_path = PathBuf::from(job.output_csv_path.clone()); + if !tokio::fs::try_exists(&output_path).await.unwrap_or(false) { + export_job_csv_snapshot(db.clone(), &job) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to export output csv {job_id}: {err}" + )) + })?; + } + let progress = db + .get_agent_job_progress(job_id.as_str()) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to load agent job progress {job_id}: {err}" + )) + })?; + let mut job_error = job.last_error.clone().filter(|err| !err.trim().is_empty()); + let failed_item_errors = if progress.failed_items > 0 { + let items = db + .list_agent_job_items( + job_id.as_str(), + Some(codex_state::AgentJobItemStatus::Failed), + Some(5), + ) + .await + .unwrap_or_default(); + let summaries: Vec<_> = items + .into_iter() + .filter_map(|item| { + let last_error = item.last_error.unwrap_or_default(); + if last_error.trim().is_empty() { + return None; + } + Some(AgentJobFailureSummary { + item_id: item.item_id, + source_id: item.source_id, + last_error, + }) + }) + .collect(); + if summaries.is_empty() { + if job_error.is_none() { + job_error = Some( + "agent job has failed items but no error details were recorded".to_string(), + ); + } + None + } else { + Some(summaries) + } + } else { + None + }; + let content = serde_json::to_string(&SpawnAgentsOnCsvResult { + job_id, + status: job.status.as_str().to_string(), + output_csv_path: job.output_csv_path, + total_items: progress.total_items, + completed_items: progress.completed_items, + failed_items: progress.failed_items, + job_error, + failed_item_errors, + }) + .map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize spawn_agents_on_csv result: {err}" + )) + })?; + Ok(FunctionToolOutput::from_text(content, Some(true))) +} diff --git a/codex-rs/core/src/tools/handlers/goal.rs b/codex-rs/core/src/tools/handlers/goal.rs index 6a7b304ce428..28e33f2be40b 100644 --- a/codex-rs/core/src/tools/handlers/goal.rs +++ b/codex-rs/core/src/tools/handlers/goal.rs @@ -5,28 +5,20 @@ //! the existing goal complete. use crate::function_tool::FunctionCallError; -use crate::goals::CreateGoalRequest; -use crate::goals::GoalRuntimeEvent; -use crate::goals::SetGoalRequest; use crate::tools::context::FunctionToolOutput; -use crate::tools::context::ToolInvocation; -use crate::tools::context::ToolPayload; -use crate::tools::handlers::parse_arguments; -use crate::tools::registry::ToolHandler; -use crate::tools::registry::ToolKind; use codex_protocol::protocol::ThreadGoal; use codex_protocol::protocol::ThreadGoalStatus; -use codex_tools::CREATE_GOAL_TOOL_NAME; -use codex_tools::GET_GOAL_TOOL_NAME; -use codex_tools::ToolName; -use codex_tools::UPDATE_GOAL_TOOL_NAME; use serde::Deserialize; use serde::Serialize; use std::fmt::Write as _; -pub struct GetGoalHandler; -pub struct CreateGoalHandler; -pub struct UpdateGoalHandler; +mod create_goal; +mod get_goal; +mod update_goal; + +pub use create_goal::CreateGoalHandler; +pub use get_goal::GetGoalHandler; +pub use update_goal::UpdateGoalHandler; #[derive(Debug, Deserialize)] #[serde(rename_all = "snake_case")] @@ -76,148 +68,6 @@ impl GoalToolResponse { } } -impl ToolHandler for GetGoalHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain(GET_GOAL_TOOL_NAME) - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, payload, .. - } = invocation; - - match payload { - ToolPayload::Function { .. } => { - let goal = session - .get_thread_goal() - .await - .map_err(|err| FunctionCallError::RespondToModel(format_goal_error(err)))?; - goal_response(goal, CompletionBudgetReport::Omit) - } - _ => Err(FunctionCallError::RespondToModel( - "get_goal handler received unsupported payload".to_string(), - )), - } - } -} - -impl ToolHandler for CreateGoalHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain(CREATE_GOAL_TOOL_NAME) - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "goal handler received unsupported payload".to_string(), - )); - } - }; - - let args: CreateGoalArgs = parse_arguments(&arguments)?; - let goal = session - .create_thread_goal( - turn.as_ref(), - CreateGoalRequest { - objective: args.objective, - token_budget: args.token_budget, - }, - ) - .await - .map_err(|err| { - if err - .chain() - .any(|cause| cause.to_string().contains("already has a goal")) - { - FunctionCallError::RespondToModel( - "cannot create a new goal because this thread already has a goal; use update_goal only when the existing goal is complete" - .to_string(), - ) - } else { - FunctionCallError::RespondToModel(format_goal_error(err)) - } - })?; - goal_response(Some(goal), CompletionBudgetReport::Omit) - } -} - -impl ToolHandler for UpdateGoalHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain(UPDATE_GOAL_TOOL_NAME) - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "update_goal handler received unsupported payload".to_string(), - )); - } - }; - - let args: UpdateGoalArgs = parse_arguments(&arguments)?; - if args.status != ThreadGoalStatus::Complete { - return Err(FunctionCallError::RespondToModel( - "update_goal can only mark the existing goal complete; pause, resume, and budget-limited status changes are controlled by the user or system" - .to_string(), - )); - } - session - .goal_runtime_apply(GoalRuntimeEvent::ToolCompletedGoal { - turn_context: turn.as_ref(), - }) - .await - .map_err(|err| FunctionCallError::RespondToModel(format_goal_error(err)))?; - let goal = session - .set_thread_goal( - turn.as_ref(), - SetGoalRequest { - objective: None, - status: Some(ThreadGoalStatus::Complete), - token_budget: None, - }, - ) - .await - .map_err(|err| FunctionCallError::RespondToModel(format_goal_error(err)))?; - goal_response(Some(goal), CompletionBudgetReport::Include) - } -} - fn format_goal_error(err: anyhow::Error) -> String { let mut message = err.to_string(); for cause in err.chain().skip(1) { diff --git a/codex-rs/core/src/tools/handlers/goal/create_goal.rs b/codex-rs/core/src/tools/handlers/goal/create_goal.rs new file mode 100644 index 000000000000..88297cc1afe7 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/goal/create_goal.rs @@ -0,0 +1,72 @@ +use crate::function_tool::FunctionCallError; +use crate::goals::CreateGoalRequest; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_tools::CREATE_GOAL_TOOL_NAME; +use codex_tools::ToolName; + +use super::CompletionBudgetReport; +use super::CreateGoalArgs; +use super::format_goal_error; +use super::goal_response; + +pub struct CreateGoalHandler; + +impl ToolHandler for CreateGoalHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain(CREATE_GOAL_TOOL_NAME) + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "goal handler received unsupported payload".to_string(), + )); + } + }; + + let args: CreateGoalArgs = parse_arguments(&arguments)?; + let goal = session + .create_thread_goal( + turn.as_ref(), + CreateGoalRequest { + objective: args.objective, + token_budget: args.token_budget, + }, + ) + .await + .map_err(|err| { + if err + .chain() + .any(|cause| cause.to_string().contains("already has a goal")) + { + FunctionCallError::RespondToModel( + "cannot create a new goal because this thread already has a goal; use update_goal only when the existing goal is complete" + .to_string(), + ) + } else { + FunctionCallError::RespondToModel(format_goal_error(err)) + } + })?; + goal_response(Some(goal), CompletionBudgetReport::Omit) + } +} diff --git a/codex-rs/core/src/tools/handlers/goal/get_goal.rs b/codex-rs/core/src/tools/handlers/goal/get_goal.rs new file mode 100644 index 000000000000..ab023f301452 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/goal/get_goal.rs @@ -0,0 +1,45 @@ +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_tools::GET_GOAL_TOOL_NAME; +use codex_tools::ToolName; + +use super::CompletionBudgetReport; +use super::format_goal_error; +use super::goal_response; + +pub struct GetGoalHandler; + +impl ToolHandler for GetGoalHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain(GET_GOAL_TOOL_NAME) + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, payload, .. + } = invocation; + + match payload { + ToolPayload::Function { .. } => { + let goal = session + .get_thread_goal() + .await + .map_err(|err| FunctionCallError::RespondToModel(format_goal_error(err)))?; + goal_response(goal, CompletionBudgetReport::Omit) + } + _ => Err(FunctionCallError::RespondToModel( + "get_goal handler received unsupported payload".to_string(), + )), + } + } +} diff --git a/codex-rs/core/src/tools/handlers/goal/update_goal.rs b/codex-rs/core/src/tools/handlers/goal/update_goal.rs new file mode 100644 index 000000000000..6c43484ec912 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/goal/update_goal.rs @@ -0,0 +1,75 @@ +use crate::function_tool::FunctionCallError; +use crate::goals::GoalRuntimeEvent; +use crate::goals::SetGoalRequest; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_protocol::protocol::ThreadGoalStatus; +use codex_tools::ToolName; +use codex_tools::UPDATE_GOAL_TOOL_NAME; + +use super::CompletionBudgetReport; +use super::UpdateGoalArgs; +use super::format_goal_error; +use super::goal_response; + +pub struct UpdateGoalHandler; + +impl ToolHandler for UpdateGoalHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain(UPDATE_GOAL_TOOL_NAME) + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "update_goal handler received unsupported payload".to_string(), + )); + } + }; + + let args: UpdateGoalArgs = parse_arguments(&arguments)?; + if args.status != ThreadGoalStatus::Complete { + return Err(FunctionCallError::RespondToModel( + "update_goal can only mark the existing goal complete; pause, resume, and budget-limited status changes are controlled by the user or system" + .to_string(), + )); + } + session + .goal_runtime_apply(GoalRuntimeEvent::ToolCompletedGoal { + turn_context: turn.as_ref(), + }) + .await + .map_err(|err| FunctionCallError::RespondToModel(format_goal_error(err)))?; + let goal = session + .set_thread_goal( + turn.as_ref(), + SetGoalRequest { + objective: None, + status: Some(ThreadGoalStatus::Complete), + token_budget: None, + }, + ) + .await + .map_err(|err| FunctionCallError::RespondToModel(format_goal_error(err)))?; + goal_response(Some(goal), CompletionBudgetReport::Include) + } +} diff --git a/codex-rs/core/src/tools/handlers/mcp_resource.rs b/codex-rs/core/src/tools/handlers/mcp_resource.rs index b03de53ae1ff..630a1cccb4e2 100644 --- a/codex-rs/core/src/tools/handlers/mcp_resource.rs +++ b/codex-rs/core/src/tools/handlers/mcp_resource.rs @@ -1,18 +1,14 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use std::time::Instant; use codex_protocol::items::McpToolCallError; use codex_protocol::items::McpToolCallItem; use codex_protocol::items::McpToolCallStatus; use codex_protocol::items::TurnItem; use codex_protocol::mcp::CallToolResult; -use codex_protocol::models::function_call_output_content_items_to_text; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; -use rmcp::model::PaginatedRequestParams; -use rmcp::model::ReadResourceRequestParams; use rmcp::model::ReadResourceResult; use rmcp::model::Resource; use rmcp::model::ResourceTemplate; @@ -25,16 +21,15 @@ use crate::function_tool::FunctionCallError; use crate::session::session::Session; use crate::session::turn_context::TurnContext; use crate::tools::context::FunctionToolOutput; -use crate::tools::context::ToolInvocation; -use crate::tools::context::ToolPayload; -use crate::tools::registry::ToolHandler; -use crate::tools::registry::ToolKind; use codex_protocol::protocol::McpInvocation; -use codex_tools::ToolName; -pub struct ListMcpResourcesHandler; -pub struct ListMcpResourceTemplatesHandler; -pub struct ReadMcpResourceHandler; +mod list_mcp_resource_templates; +mod list_mcp_resources; +mod read_mcp_resource; + +pub use list_mcp_resource_templates::ListMcpResourceTemplatesHandler; +pub use list_mcp_resources::ListMcpResourcesHandler; +pub use read_mcp_resource::ReadMcpResourceHandler; #[derive(Debug, Deserialize, Default)] struct ListResourcesArgs { @@ -181,390 +176,6 @@ struct ReadResourcePayload { result: ReadResourceResult, } -impl ToolHandler for ListMcpResourcesHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("list_mcp_resources") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - #[expect( - clippy::await_holding_invalid_type, - reason = "MCP resource listing reads through the session-owned manager guard" - )] - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "list_mcp_resources handler received unsupported payload".to_string(), - )); - } - }; - - let arguments = parse_arguments(arguments.as_str())?; - let args: ListResourcesArgs = parse_args_with_default(arguments.clone())?; - let ListResourcesArgs { server, cursor } = args; - let server = normalize_optional_string(server); - let cursor = normalize_optional_string(cursor); - - let invocation = McpInvocation { - server: server.clone().unwrap_or_else(|| "codex".to_string()), - tool: "list_mcp_resources".to_string(), - arguments: arguments.clone(), - }; - - emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; - let start = Instant::now(); - - let payload_result: Result = async { - if let Some(server_name) = server.clone() { - let params = cursor.clone().map(|value| PaginatedRequestParams { - meta: None, - cursor: Some(value), - }); - let result = session - .list_resources(&server_name, params) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!("resources/list failed: {err:#}")) - })?; - Ok(ListResourcesPayload::from_single_server( - server_name, - result, - )) - } else { - if cursor.is_some() { - return Err(FunctionCallError::RespondToModel( - "cursor can only be used when a server is specified".to_string(), - )); - } - - let resources = session - .services - .mcp_connection_manager - .read() - .await - .list_all_resources() - .await; - Ok(ListResourcesPayload::from_all_servers(resources)) - } - } - .await; - - match payload_result { - Ok(payload) => match serialize_function_output(payload) { - Ok(output) => { - let content = function_call_output_content_items_to_text(&output.body) - .unwrap_or_default(); - let duration = start.elapsed(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Ok(call_tool_result_from_content(&content, output.success)), - ) - .await; - Ok(output) - } - Err(err) => { - let duration = start.elapsed(); - let message = err.to_string(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Err(message.clone()), - ) - .await; - Err(err) - } - }, - Err(err) => { - let duration = start.elapsed(); - let message = err.to_string(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Err(message.clone()), - ) - .await; - Err(err) - } - } - } -} - -impl ToolHandler for ListMcpResourceTemplatesHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("list_mcp_resource_templates") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - #[expect( - clippy::await_holding_invalid_type, - reason = "MCP resource template listing reads through the session-owned manager guard" - )] - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "list_mcp_resource_templates handler received unsupported payload".to_string(), - )); - } - }; - - let arguments = parse_arguments(arguments.as_str())?; - let args: ListResourceTemplatesArgs = parse_args_with_default(arguments.clone())?; - let ListResourceTemplatesArgs { server, cursor } = args; - let server = normalize_optional_string(server); - let cursor = normalize_optional_string(cursor); - - let invocation = McpInvocation { - server: server.clone().unwrap_or_else(|| "codex".to_string()), - tool: "list_mcp_resource_templates".to_string(), - arguments: arguments.clone(), - }; - - emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; - let start = Instant::now(); - - let payload_result: Result = async { - if let Some(server_name) = server.clone() { - let params = cursor.clone().map(|value| PaginatedRequestParams { - meta: None, - cursor: Some(value), - }); - let result = session - .list_resource_templates(&server_name, params) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "resources/templates/list failed: {err:#}" - )) - })?; - Ok(ListResourceTemplatesPayload::from_single_server( - server_name, - result, - )) - } else { - if cursor.is_some() { - return Err(FunctionCallError::RespondToModel( - "cursor can only be used when a server is specified".to_string(), - )); - } - - let templates = session - .services - .mcp_connection_manager - .read() - .await - .list_all_resource_templates() - .await; - Ok(ListResourceTemplatesPayload::from_all_servers(templates)) - } - } - .await; - - match payload_result { - Ok(payload) => match serialize_function_output(payload) { - Ok(output) => { - let content = function_call_output_content_items_to_text(&output.body) - .unwrap_or_default(); - let duration = start.elapsed(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Ok(call_tool_result_from_content(&content, output.success)), - ) - .await; - Ok(output) - } - Err(err) => { - let duration = start.elapsed(); - let message = err.to_string(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Err(message.clone()), - ) - .await; - Err(err) - } - }, - Err(err) => { - let duration = start.elapsed(); - let message = err.to_string(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Err(message.clone()), - ) - .await; - Err(err) - } - } - } -} - -impl ToolHandler for ReadMcpResourceHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("read_mcp_resource") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "read_mcp_resource handler received unsupported payload".to_string(), - )); - } - }; - - let arguments = parse_arguments(arguments.as_str())?; - let args: ReadResourceArgs = parse_args(arguments.clone())?; - let ReadResourceArgs { server, uri } = args; - let server = normalize_required_string("server", server)?; - let uri = normalize_required_string("uri", uri)?; - - let invocation = McpInvocation { - server: server.clone(), - tool: "read_mcp_resource".to_string(), - arguments: arguments.clone(), - }; - - emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; - let start = Instant::now(); - - let payload_result: Result = async { - let result = session - .read_resource( - &server, - ReadResourceRequestParams { - meta: None, - uri: uri.clone(), - }, - ) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!("resources/read failed: {err:#}")) - })?; - - Ok(ReadResourcePayload { - server, - uri, - result, - }) - } - .await; - - match payload_result { - Ok(payload) => match serialize_function_output(payload) { - Ok(output) => { - let content = function_call_output_content_items_to_text(&output.body) - .unwrap_or_default(); - let duration = start.elapsed(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Ok(call_tool_result_from_content(&content, output.success)), - ) - .await; - Ok(output) - } - Err(err) => { - let duration = start.elapsed(); - let message = err.to_string(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Err(message.clone()), - ) - .await; - Err(err) - } - }, - Err(err) => { - let duration = start.elapsed(); - let message = err.to_string(); - emit_tool_call_end( - &session, - turn.as_ref(), - &call_id, - invocation, - duration, - Err(message.clone()), - ) - .await; - Err(err) - } - } - } -} - fn call_tool_result_from_content(content: &str, success: Option) -> CallToolResult { CallToolResult { content: vec![serde_json::json!({"type": "text", "text": content})], diff --git a/codex-rs/core/src/tools/handlers/mcp_resource/list_mcp_resource_templates.rs b/codex-rs/core/src/tools/handlers/mcp_resource/list_mcp_resource_templates.rs new file mode 100644 index 000000000000..2c87edf0c78a --- /dev/null +++ b/codex-rs/core/src/tools/handlers/mcp_resource/list_mcp_resource_templates.rs @@ -0,0 +1,160 @@ +use std::time::Instant; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_protocol::models::function_call_output_content_items_to_text; +use codex_protocol::protocol::McpInvocation; +use codex_tools::ToolName; + +use rmcp::model::PaginatedRequestParams; + +use super::ListResourceTemplatesArgs; +use super::ListResourceTemplatesPayload; +use super::call_tool_result_from_content; +use super::emit_tool_call_begin; +use super::emit_tool_call_end; +use super::normalize_optional_string; +use super::parse_args_with_default; +use super::parse_arguments; +use super::serialize_function_output; + +pub struct ListMcpResourceTemplatesHandler; + +impl ToolHandler for ListMcpResourceTemplatesHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("list_mcp_resource_templates") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + #[expect( + clippy::await_holding_invalid_type, + reason = "MCP resource template listing reads through the session-owned manager guard" + )] + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "list_mcp_resource_templates handler received unsupported payload".to_string(), + )); + } + }; + + let arguments = parse_arguments(arguments.as_str())?; + let args: ListResourceTemplatesArgs = parse_args_with_default(arguments.clone())?; + let ListResourceTemplatesArgs { server, cursor } = args; + let server = normalize_optional_string(server); + let cursor = normalize_optional_string(cursor); + + let invocation = McpInvocation { + server: server.clone().unwrap_or_else(|| "codex".to_string()), + tool: "list_mcp_resource_templates".to_string(), + arguments: arguments.clone(), + }; + + emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; + let start = Instant::now(); + + let payload_result: Result = async { + if let Some(server_name) = server.clone() { + let params = cursor.clone().map(|value| PaginatedRequestParams { + meta: None, + cursor: Some(value), + }); + let result = session + .list_resource_templates(&server_name, params) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "resources/templates/list failed: {err:#}" + )) + })?; + Ok(ListResourceTemplatesPayload::from_single_server( + server_name, + result, + )) + } else { + if cursor.is_some() { + return Err(FunctionCallError::RespondToModel( + "cursor can only be used when a server is specified".to_string(), + )); + } + + let templates = session + .services + .mcp_connection_manager + .read() + .await + .list_all_resource_templates() + .await; + Ok(ListResourceTemplatesPayload::from_all_servers(templates)) + } + } + .await; + + match payload_result { + Ok(payload) => match serialize_function_output(payload) { + Ok(output) => { + let content = function_call_output_content_items_to_text(&output.body) + .unwrap_or_default(); + let duration = start.elapsed(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Ok(call_tool_result_from_content(&content, output.success)), + ) + .await; + Ok(output) + } + Err(err) => { + let duration = start.elapsed(); + let message = err.to_string(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Err(message.clone()), + ) + .await; + Err(err) + } + }, + Err(err) => { + let duration = start.elapsed(); + let message = err.to_string(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Err(message.clone()), + ) + .await; + Err(err) + } + } + } +} diff --git a/codex-rs/core/src/tools/handlers/mcp_resource/list_mcp_resources.rs b/codex-rs/core/src/tools/handlers/mcp_resource/list_mcp_resources.rs new file mode 100644 index 000000000000..ed6285214116 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/mcp_resource/list_mcp_resources.rs @@ -0,0 +1,158 @@ +use std::time::Instant; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_protocol::models::function_call_output_content_items_to_text; +use codex_protocol::protocol::McpInvocation; +use codex_tools::ToolName; + +use rmcp::model::PaginatedRequestParams; + +use super::ListResourcesArgs; +use super::ListResourcesPayload; +use super::call_tool_result_from_content; +use super::emit_tool_call_begin; +use super::emit_tool_call_end; +use super::normalize_optional_string; +use super::parse_args_with_default; +use super::parse_arguments; +use super::serialize_function_output; + +pub struct ListMcpResourcesHandler; + +impl ToolHandler for ListMcpResourcesHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("list_mcp_resources") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + #[expect( + clippy::await_holding_invalid_type, + reason = "MCP resource listing reads through the session-owned manager guard" + )] + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "list_mcp_resources handler received unsupported payload".to_string(), + )); + } + }; + + let arguments = parse_arguments(arguments.as_str())?; + let args: ListResourcesArgs = parse_args_with_default(arguments.clone())?; + let ListResourcesArgs { server, cursor } = args; + let server = normalize_optional_string(server); + let cursor = normalize_optional_string(cursor); + + let invocation = McpInvocation { + server: server.clone().unwrap_or_else(|| "codex".to_string()), + tool: "list_mcp_resources".to_string(), + arguments: arguments.clone(), + }; + + emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; + let start = Instant::now(); + + let payload_result: Result = async { + if let Some(server_name) = server.clone() { + let params = cursor.clone().map(|value| PaginatedRequestParams { + meta: None, + cursor: Some(value), + }); + let result = session + .list_resources(&server_name, params) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("resources/list failed: {err:#}")) + })?; + Ok(ListResourcesPayload::from_single_server( + server_name, + result, + )) + } else { + if cursor.is_some() { + return Err(FunctionCallError::RespondToModel( + "cursor can only be used when a server is specified".to_string(), + )); + } + + let resources = session + .services + .mcp_connection_manager + .read() + .await + .list_all_resources() + .await; + Ok(ListResourcesPayload::from_all_servers(resources)) + } + } + .await; + + match payload_result { + Ok(payload) => match serialize_function_output(payload) { + Ok(output) => { + let content = function_call_output_content_items_to_text(&output.body) + .unwrap_or_default(); + let duration = start.elapsed(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Ok(call_tool_result_from_content(&content, output.success)), + ) + .await; + Ok(output) + } + Err(err) => { + let duration = start.elapsed(); + let message = err.to_string(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Err(message.clone()), + ) + .await; + Err(err) + } + }, + Err(err) => { + let duration = start.elapsed(); + let message = err.to_string(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Err(message.clone()), + ) + .await; + Err(err) + } + } + } +} diff --git a/codex-rs/core/src/tools/handlers/mcp_resource/read_mcp_resource.rs b/codex-rs/core/src/tools/handlers/mcp_resource/read_mcp_resource.rs new file mode 100644 index 000000000000..91d5a4317e50 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/mcp_resource/read_mcp_resource.rs @@ -0,0 +1,141 @@ +use std::time::Instant; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use codex_protocol::models::function_call_output_content_items_to_text; +use codex_protocol::protocol::McpInvocation; +use codex_tools::ToolName; + +use rmcp::model::ReadResourceRequestParams; + +use super::ReadResourceArgs; +use super::ReadResourcePayload; +use super::call_tool_result_from_content; +use super::emit_tool_call_begin; +use super::emit_tool_call_end; +use super::normalize_required_string; +use super::parse_args; +use super::parse_arguments; +use super::serialize_function_output; + +pub struct ReadMcpResourceHandler; + +impl ToolHandler for ReadMcpResourceHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("read_mcp_resource") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "read_mcp_resource handler received unsupported payload".to_string(), + )); + } + }; + + let arguments = parse_arguments(arguments.as_str())?; + let args: ReadResourceArgs = parse_args(arguments.clone())?; + let ReadResourceArgs { server, uri } = args; + let server = normalize_required_string("server", server)?; + let uri = normalize_required_string("uri", uri)?; + + let invocation = McpInvocation { + server: server.clone(), + tool: "read_mcp_resource".to_string(), + arguments: arguments.clone(), + }; + + emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; + let start = Instant::now(); + + let payload_result: Result = async { + let result = session + .read_resource( + &server, + ReadResourceRequestParams { + meta: None, + uri: uri.clone(), + }, + ) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("resources/read failed: {err:#}")) + })?; + + Ok(ReadResourcePayload { + server, + uri, + result, + }) + } + .await; + + match payload_result { + Ok(payload) => match serialize_function_output(payload) { + Ok(output) => { + let content = function_call_output_content_items_to_text(&output.body) + .unwrap_or_default(); + let duration = start.elapsed(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Ok(call_tool_result_from_content(&content, output.success)), + ) + .await; + Ok(output) + } + Err(err) => { + let duration = start.elapsed(); + let message = err.to_string(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Err(message.clone()), + ) + .await; + Err(err) + } + }, + Err(err) => { + let duration = start.elapsed(); + let message = err.to_string(); + emit_tool_call_end( + &session, + turn.as_ref(), + &call_id, + invocation, + duration, + Err(message.clone()), + ) + .await; + Err(err) + } + } + } +} diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index dfc3c3e5b9c5..469c0a0799f0 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -1,17 +1,13 @@ -use codex_protocol::ThreadId; +use codex_features::Feature; use codex_protocol::models::ShellCommandToolCallParams; use codex_protocol::models::ShellToolCallParams; use serde_json::Value as JsonValue; use std::sync::Arc; -use crate::exec::ExecCapturePolicy; use crate::exec::ExecParams; -use crate::exec_env::create_env; use crate::exec_policy::ExecApprovalRequest; use crate::function_tool::FunctionCallError; -use crate::maybe_emit_implicit_skill_invocation; use crate::session::turn_context::TurnContext; -use crate::shell::Shell; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; @@ -23,38 +19,26 @@ use crate::tools::handlers::apply_patch::intercept_apply_patch; use crate::tools::handlers::implicit_granted_permissions; use crate::tools::handlers::normalize_and_validate_additional_permissions; use crate::tools::handlers::parse_arguments; -use crate::tools::handlers::parse_arguments_with_base_path; -use crate::tools::handlers::resolve_workdir_base_path; use crate::tools::hook_names::HookToolName; use crate::tools::orchestrator::ToolOrchestrator; use crate::tools::registry::PostToolUsePayload; use crate::tools::registry::PreToolUsePayload; -use crate::tools::registry::ToolHandler; -use crate::tools::registry::ToolKind; use crate::tools::runtimes::shell::ShellRequest; use crate::tools::runtimes::shell::ShellRuntime; use crate::tools::runtimes::shell::ShellRuntimeBackend; use crate::tools::sandboxing::ToolCtx; -use codex_features::Feature; use codex_protocol::models::AdditionalPermissionProfile; use codex_protocol::protocol::ExecCommandSource; -use codex_shell_command::is_safe_command::is_known_safe_command; -use codex_tools::ShellCommandBackendConfig; -use codex_tools::ToolName; -pub struct ShellHandler; -pub struct ContainerExecHandler; -pub struct LocalShellHandler; +mod container_exec; +mod local_shell; +mod shell_command; +mod shell_handler; -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum ShellCommandBackend { - Classic, - ZshFork, -} - -pub struct ShellCommandHandler { - backend: ShellCommandBackend, -} +pub use container_exec::ContainerExecHandler; +pub use local_shell::LocalShellHandler; +pub use shell_command::ShellCommandHandler; +pub use shell_handler::ShellHandler; fn shell_function_payload_command(payload: &ToolPayload) -> Option { let ToolPayload::Function { arguments } = payload else { @@ -100,333 +84,6 @@ struct RunExecLikeArgs { shell_runtime_backend: ShellRuntimeBackend, } -impl ShellHandler { - fn to_exec_params( - params: &ShellToolCallParams, - turn_context: &TurnContext, - thread_id: ThreadId, - ) -> ExecParams { - ExecParams { - command: params.command.clone(), - cwd: turn_context.resolve_path(params.workdir.clone()), - expiration: params.timeout_ms.into(), - capture_policy: ExecCapturePolicy::ShellTool, - env: create_env(&turn_context.shell_environment_policy, Some(thread_id)), - network: turn_context.network.clone(), - sandbox_permissions: params.sandbox_permissions.unwrap_or_default(), - windows_sandbox_level: turn_context.windows_sandbox_level, - windows_sandbox_private_desktop: turn_context - .config - .permissions - .windows_sandbox_private_desktop, - justification: params.justification.clone(), - arg0: None, - } - } -} - -impl ShellCommandHandler { - fn shell_runtime_backend(&self) -> ShellRuntimeBackend { - match self.backend { - ShellCommandBackend::Classic => ShellRuntimeBackend::ShellCommandClassic, - ShellCommandBackend::ZshFork => ShellRuntimeBackend::ShellCommandZshFork, - } - } - - fn resolve_use_login_shell( - login: Option, - allow_login_shell: bool, - ) -> Result { - if !allow_login_shell && login == Some(true) { - return Err(FunctionCallError::RespondToModel( - "login shell is disabled by config; omit `login` or set it to false.".to_string(), - )); - } - - Ok(login.unwrap_or(allow_login_shell)) - } - - fn base_command(shell: &Shell, command: &str, use_login_shell: bool) -> Vec { - shell.derive_exec_args(command, use_login_shell) - } - - fn to_exec_params( - params: &ShellCommandToolCallParams, - session: &crate::session::session::Session, - turn_context: &TurnContext, - thread_id: ThreadId, - allow_login_shell: bool, - ) -> Result { - let shell = session.user_shell(); - let use_login_shell = Self::resolve_use_login_shell(params.login, allow_login_shell)?; - let command = Self::base_command(shell.as_ref(), ¶ms.command, use_login_shell); - - Ok(ExecParams { - command, - cwd: turn_context.resolve_path(params.workdir.clone()), - expiration: params.timeout_ms.into(), - capture_policy: ExecCapturePolicy::ShellTool, - env: create_env(&turn_context.shell_environment_policy, Some(thread_id)), - network: turn_context.network.clone(), - sandbox_permissions: params.sandbox_permissions.unwrap_or_default(), - windows_sandbox_level: turn_context.windows_sandbox_level, - windows_sandbox_private_desktop: turn_context - .config - .permissions - .windows_sandbox_private_desktop, - justification: params.justification.clone(), - arg0: None, - }) - } -} - -impl From for ShellCommandHandler { - fn from(config: ShellCommandBackendConfig) -> Self { - let backend = match config { - ShellCommandBackendConfig::Classic => ShellCommandBackend::Classic, - ShellCommandBackendConfig::ZshFork => ShellCommandBackend::ZshFork, - }; - Self { backend } - } -} - -impl ToolHandler for ShellHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("shell") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { - let ToolPayload::Function { arguments } = &invocation.payload else { - return true; - }; - - serde_json::from_str::(arguments) - .map(|params| !is_known_safe_command(¶ms.command)) - .unwrap_or(true) - } - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - shell_function_pre_tool_use_payload(invocation) - } - - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &Self::Output, - ) -> Option { - shell_function_post_tool_use_payload(invocation, result) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "unsupported payload for shell handler".to_string(), - )); - } - }; - - let cwd = resolve_workdir_base_path(&arguments, &turn.cwd)?; - let params: ShellToolCallParams = parse_arguments_with_base_path(&arguments, &cwd)?; - let prefix_rule = params.prefix_rule.clone(); - let exec_params = - ShellHandler::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); - ShellHandler::run_exec_like(RunExecLikeArgs { - tool_name: "shell".to_string(), - exec_params, - hook_command: codex_shell_command::parse_command::shlex_join(¶ms.command), - additional_permissions: params.additional_permissions.clone(), - prefix_rule, - session, - turn, - tracker, - call_id, - freeform: false, - shell_runtime_backend: ShellRuntimeBackend::Generic, - }) - .await - } -} - -impl ToolHandler for ContainerExecHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("container.exec") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { - let ToolPayload::Function { arguments } = &invocation.payload else { - return true; - }; - - serde_json::from_str::(arguments) - .map(|params| !is_known_safe_command(¶ms.command)) - .unwrap_or(true) - } - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - shell_function_pre_tool_use_payload(invocation) - } - - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &Self::Output, - ) -> Option { - shell_function_post_tool_use_payload(invocation, result) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "unsupported payload for container.exec handler".to_string(), - )); - } - }; - - let cwd = resolve_workdir_base_path(&arguments, &turn.cwd)?; - let params: ShellToolCallParams = parse_arguments_with_base_path(&arguments, &cwd)?; - let prefix_rule = params.prefix_rule.clone(); - let exec_params = - ShellHandler::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); - ShellHandler::run_exec_like(RunExecLikeArgs { - tool_name: "container.exec".to_string(), - exec_params, - hook_command: codex_shell_command::parse_command::shlex_join(¶ms.command), - additional_permissions: params.additional_permissions.clone(), - prefix_rule, - session, - turn, - tracker, - call_id, - freeform: false, - shell_runtime_backend: ShellRuntimeBackend::Generic, - }) - .await - } -} - -impl ToolHandler for LocalShellHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("local_shell") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::LocalShell { .. }) - } - - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { - let ToolPayload::LocalShell { params } = &invocation.payload else { - return true; - }; - - !is_known_safe_command(¶ms.command) - } - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - local_shell_payload_command(&invocation.payload).map(|command| PreToolUsePayload { - tool_name: HookToolName::bash(), - tool_input: serde_json::json!({ "command": command }), - }) - } - - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &Self::Output, - ) -> Option { - let tool_response = - result.post_tool_use_response(&invocation.call_id, &invocation.payload)?; - let command = local_shell_payload_command(&invocation.payload)?; - Some(PostToolUsePayload { - tool_name: HookToolName::bash(), - tool_use_id: invocation.call_id.clone(), - tool_input: serde_json::json!({ "command": command }), - tool_response, - }) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - payload, - .. - } = invocation; - - let ToolPayload::LocalShell { params } = payload else { - return Err(FunctionCallError::RespondToModel( - "unsupported payload for local_shell handler".to_string(), - )); - }; - - let exec_params = - ShellHandler::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); - ShellHandler::run_exec_like(RunExecLikeArgs { - tool_name: "local_shell".to_string(), - exec_params, - hook_command: codex_shell_command::parse_command::shlex_join(¶ms.command), - additional_permissions: None, - prefix_rule: None, - session, - turn, - tracker, - call_id, - freeform: false, - shell_runtime_backend: ShellRuntimeBackend::Generic, - }) - .await - } -} - fn shell_function_pre_tool_use_payload(invocation: &ToolInvocation) -> Option { shell_function_payload_command(&invocation.payload).map(|command| PreToolUsePayload { tool_name: HookToolName::bash(), @@ -448,316 +105,203 @@ fn shell_function_post_tool_use_payload( }) } -impl ToolHandler for ShellCommandHandler { - type Output = FunctionToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("shell_command") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { - let ToolPayload::Function { arguments } = &invocation.payload else { - return true; - }; - - serde_json::from_str::(arguments) - .map(|params| { - let use_login_shell = match Self::resolve_use_login_shell( - params.login, - invocation.turn.tools_config.allow_login_shell, - ) { - Ok(use_login_shell) => use_login_shell, - Err(_) => return true, - }; - let shell = invocation.session.user_shell(); - let command = Self::base_command(shell.as_ref(), ¶ms.command, use_login_shell); - !is_known_safe_command(&command) - }) - .unwrap_or(true) - } +async fn run_exec_like(args: RunExecLikeArgs) -> Result { + let RunExecLikeArgs { + tool_name, + exec_params, + hook_command, + additional_permissions, + prefix_rule, + session, + turn, + tracker, + call_id, + freeform, + shell_runtime_backend, + } = args; + + let mut exec_params = exec_params; + let Some(turn_environment) = turn.environments.primary() else { + return Err(FunctionCallError::RespondToModel( + "shell is unavailable in this session".to_string(), + )); + }; + let fs = turn_environment.environment.get_filesystem(); - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - shell_command_payload_command(&invocation.payload).map(|command| PreToolUsePayload { - tool_name: HookToolName::bash(), - tool_input: serde_json::json!({ "command": command }), - }) + let dependency_env = session.dependency_env().await; + if !dependency_env.is_empty() { + exec_params.env.extend(dependency_env.clone()); } - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &Self::Output, - ) -> Option { - let tool_response = - result.post_tool_use_response(&invocation.call_id, &invocation.payload)?; - let command = shell_command_payload_command(&invocation.payload)?; - Some(PostToolUsePayload { - tool_name: HookToolName::bash(), - tool_use_id: invocation.call_id.clone(), - tool_input: serde_json::json!({ "command": command }), - tool_response, - }) + let mut explicit_env_overrides = turn.shell_environment_policy.r#set.clone(); + for key in dependency_env.keys() { + if let Some(value) = exec_params.env.get(key) { + explicit_env_overrides.insert(key.clone(), value.clone()); + } } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - payload, - .. - } = invocation; - - let ToolPayload::Function { arguments } = payload else { - return Err(FunctionCallError::RespondToModel(format!( - "unsupported payload for shell_command handler: {}", - self.tool_name().display() - ))); - }; - - let cwd = resolve_workdir_base_path(&arguments, &turn.cwd)?; - let params: ShellCommandToolCallParams = parse_arguments_with_base_path(&arguments, &cwd)?; - let workdir = turn.resolve_path(params.workdir.clone()); - maybe_emit_implicit_skill_invocation( - session.as_ref(), - turn.as_ref(), - ¶ms.command, - &workdir, + let exec_permission_approvals_enabled = + session.features().enabled(Feature::ExecPermissionApprovals); + let requested_additional_permissions = additional_permissions.clone(); + let effective_additional_permissions = apply_granted_turn_permissions( + session.as_ref(), + turn.cwd.as_path(), + exec_params.sandbox_permissions, + additional_permissions, + ) + .await; + let additional_permissions_allowed = exec_permission_approvals_enabled + || (session.features().enabled(Feature::RequestPermissionsTool) + && effective_additional_permissions.permissions_preapproved); + let normalized_additional_permissions = implicit_granted_permissions( + exec_params.sandbox_permissions, + requested_additional_permissions.as_ref(), + &effective_additional_permissions, + ) + .map_or_else( + || { + normalize_and_validate_additional_permissions( + additional_permissions_allowed, + turn.approval_policy.value(), + effective_additional_permissions.sandbox_permissions, + effective_additional_permissions.additional_permissions, + effective_additional_permissions.permissions_preapproved, + &exec_params.cwd, + ) + }, + |permissions| Ok(Some(permissions)), + ) + .map_err(FunctionCallError::RespondToModel)?; + + // Approval policy guard for explicit escalation in non-OnRequest modes. + // Sticky turn permissions have already been approved, so they should + // continue through the normal exec approval flow for the command. + if effective_additional_permissions + .sandbox_permissions + .requests_sandbox_override() + && !effective_additional_permissions.permissions_preapproved + && !matches!( + turn.approval_policy.value(), + codex_protocol::protocol::AskForApproval::OnRequest ) - .await; - let prefix_rule = params.prefix_rule.clone(); - let exec_params = Self::to_exec_params( - ¶ms, - session.as_ref(), - turn.as_ref(), - session.conversation_id, - turn.tools_config.allow_login_shell, - )?; - ShellHandler::run_exec_like(RunExecLikeArgs { - tool_name: self.tool_name().display(), - exec_params, - hook_command: params.command, - additional_permissions: params.additional_permissions.clone(), + { + let approval_policy = turn.approval_policy.value(); + return Err(FunctionCallError::RespondToModel(format!( + "approval policy is {approval_policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {approval_policy:?}" + ))); + } + + // Intercept apply_patch if present. + if let Some(output) = intercept_apply_patch( + &exec_params.command, + &exec_params.cwd, + fs.as_ref(), + session.clone(), + turn.clone(), + Some(&tracker), + &call_id, + tool_name.as_str(), + ) + .await? + { + return Ok(output); + } + + let source = ExecCommandSource::Agent; + let emitter = ToolEmitter::shell( + exec_params.command.clone(), + exec_params.cwd.clone(), + source, + freeform, + ); + let event_ctx = ToolEventCtx::new( + session.as_ref(), + turn.as_ref(), + &call_id, + /*turn_diff_tracker*/ None, + ); + emitter.begin(event_ctx).await; + + let file_system_sandbox_policy = turn.file_system_sandbox_policy(); + let exec_approval_requirement = session + .services + .exec_policy + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &exec_params.command, + approval_policy: turn.approval_policy.value(), + permission_profile: turn.permission_profile(), + file_system_sandbox_policy: &file_system_sandbox_policy, + sandbox_cwd: turn.cwd.as_path(), + sandbox_permissions: if effective_additional_permissions.permissions_preapproved { + codex_protocol::models::SandboxPermissions::UseDefault + } else { + effective_additional_permissions.sandbox_permissions + }, prefix_rule, - session, - turn, - tracker, - call_id, - freeform: true, - shell_runtime_backend: self.shell_runtime_backend(), }) - .await - } -} - -impl ShellHandler { - async fn run_exec_like(args: RunExecLikeArgs) -> Result { - let RunExecLikeArgs { - tool_name, - exec_params, - hook_command, - additional_permissions, - prefix_rule, - session, - turn, - tracker, - call_id, - freeform, - shell_runtime_backend, - } = args; - - let mut exec_params = exec_params; - let Some(turn_environment) = turn.environments.primary() else { - return Err(FunctionCallError::RespondToModel( - "shell is unavailable in this session".to_string(), - )); - }; - let fs = turn_environment.environment.get_filesystem(); - - let dependency_env = session.dependency_env().await; - if !dependency_env.is_empty() { - exec_params.env.extend(dependency_env.clone()); - } - - let mut explicit_env_overrides = turn.shell_environment_policy.r#set.clone(); - for key in dependency_env.keys() { - if let Some(value) = exec_params.env.get(key) { - explicit_env_overrides.insert(key.clone(), value.clone()); - } - } - - let exec_permission_approvals_enabled = - session.features().enabled(Feature::ExecPermissionApprovals); - let requested_additional_permissions = additional_permissions.clone(); - let effective_additional_permissions = apply_granted_turn_permissions( - session.as_ref(), - turn.cwd.as_path(), - exec_params.sandbox_permissions, - additional_permissions, - ) .await; - let additional_permissions_allowed = exec_permission_approvals_enabled - || (session.features().enabled(Feature::RequestPermissionsTool) - && effective_additional_permissions.permissions_preapproved); - let normalized_additional_permissions = implicit_granted_permissions( - exec_params.sandbox_permissions, - requested_additional_permissions.as_ref(), - &effective_additional_permissions, - ) - .map_or_else( - || { - normalize_and_validate_additional_permissions( - additional_permissions_allowed, - turn.approval_policy.value(), - effective_additional_permissions.sandbox_permissions, - effective_additional_permissions.additional_permissions, - effective_additional_permissions.permissions_preapproved, - &exec_params.cwd, - ) - }, - |permissions| Ok(Some(permissions)), - ) - .map_err(FunctionCallError::RespondToModel)?; - // Approval policy guard for explicit escalation in non-OnRequest modes. - // Sticky turn permissions have already been approved, so they should - // continue through the normal exec approval flow for the command. - if effective_additional_permissions - .sandbox_permissions - .requests_sandbox_override() - && !effective_additional_permissions.permissions_preapproved - && !matches!( - turn.approval_policy.value(), - codex_protocol::protocol::AskForApproval::OnRequest - ) - { - let approval_policy = turn.approval_policy.value(); - return Err(FunctionCallError::RespondToModel(format!( - "approval policy is {approval_policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {approval_policy:?}" - ))); + let req = ShellRequest { + command: exec_params.command.clone(), + hook_command, + cwd: exec_params.cwd.clone(), + timeout_ms: exec_params.expiration.timeout_ms(), + env: exec_params.env.clone(), + explicit_env_overrides, + network: exec_params.network.clone(), + sandbox_permissions: effective_additional_permissions.sandbox_permissions, + additional_permissions: normalized_additional_permissions, + #[cfg(unix)] + additional_permissions_preapproved: effective_additional_permissions + .permissions_preapproved, + justification: exec_params.justification.clone(), + exec_approval_requirement, + }; + let mut orchestrator = ToolOrchestrator::new(); + let mut runtime = { + use ShellRuntimeBackend::*; + match shell_runtime_backend { + Generic => ShellRuntime::new(), + backend @ (ShellCommandClassic | ShellCommandZshFork) => { + ShellRuntime::for_shell_command(backend) + } } - - // Intercept apply_patch if present. - if let Some(output) = intercept_apply_patch( - &exec_params.command, - &exec_params.cwd, - fs.as_ref(), - session.clone(), - turn.clone(), - Some(&tracker), - &call_id, - tool_name.as_str(), + }; + let tool_ctx = ToolCtx { + session: session.clone(), + turn: turn.clone(), + call_id: call_id.clone(), + tool_name, + }; + let out = orchestrator + .run( + &mut runtime, + &req, + &tool_ctx, + &turn, + turn.approval_policy.value(), ) - .await? - { - return Ok(output); - } - - let source = ExecCommandSource::Agent; - let emitter = ToolEmitter::shell( - exec_params.command.clone(), - exec_params.cwd.clone(), - source, - freeform, - ); - let event_ctx = ToolEventCtx::new( - session.as_ref(), - turn.as_ref(), - &call_id, - /*turn_diff_tracker*/ None, - ); - emitter.begin(event_ctx).await; - - let file_system_sandbox_policy = turn.file_system_sandbox_policy(); - let exec_approval_requirement = session - .services - .exec_policy - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &exec_params.command, - approval_policy: turn.approval_policy.value(), - permission_profile: turn.permission_profile(), - file_system_sandbox_policy: &file_system_sandbox_policy, - sandbox_cwd: turn.cwd.as_path(), - sandbox_permissions: if effective_additional_permissions.permissions_preapproved { - codex_protocol::models::SandboxPermissions::UseDefault - } else { - effective_additional_permissions.sandbox_permissions - }, - prefix_rule, - }) - .await; - - let req = ShellRequest { - command: exec_params.command.clone(), - hook_command, - cwd: exec_params.cwd.clone(), - timeout_ms: exec_params.expiration.timeout_ms(), - env: exec_params.env.clone(), - explicit_env_overrides, - network: exec_params.network.clone(), - sandbox_permissions: effective_additional_permissions.sandbox_permissions, - additional_permissions: normalized_additional_permissions, - #[cfg(unix)] - additional_permissions_preapproved: effective_additional_permissions - .permissions_preapproved, - justification: exec_params.justification.clone(), - exec_approval_requirement, - }; - let mut orchestrator = ToolOrchestrator::new(); - let mut runtime = { - use ShellRuntimeBackend::*; - match shell_runtime_backend { - Generic => ShellRuntime::new(), - backend @ (ShellCommandClassic | ShellCommandZshFork) => { - ShellRuntime::for_shell_command(backend) - } - } - }; - let tool_ctx = ToolCtx { - session: session.clone(), - turn: turn.clone(), - call_id: call_id.clone(), - tool_name, - }; - let out = orchestrator - .run( - &mut runtime, - &req, - &tool_ctx, - &turn, - turn.approval_policy.value(), - ) - .await - .map(|result| result.output); - let event_ctx = ToolEventCtx::new( - session.as_ref(), - turn.as_ref(), - &call_id, - /*turn_diff_tracker*/ None, - ); - let post_tool_use_response = out - .as_ref() - .ok() - .map(|output| crate::tools::format_exec_output_str(output, turn.truncation_policy)) - .map(JsonValue::String); - let content = emitter.finish(event_ctx, out).await?; - Ok(FunctionToolOutput { - body: vec![ - codex_protocol::models::FunctionCallOutputContentItem::InputText { text: content }, - ], - success: Some(true), - post_tool_use_response, - }) - } + .await + .map(|result| result.output); + let event_ctx = ToolEventCtx::new( + session.as_ref(), + turn.as_ref(), + &call_id, + /*turn_diff_tracker*/ None, + ); + let post_tool_use_response = out + .as_ref() + .ok() + .map(|output| crate::tools::format_exec_output_str(output, turn.truncation_policy)) + .map(JsonValue::String); + let content = emitter.finish(event_ctx, out).await?; + Ok(FunctionToolOutput { + body: vec![ + codex_protocol::models::FunctionCallOutputContentItem::InputText { text: content }, + ], + success: Some(true), + post_tool_use_response, + }) } #[cfg(test)] diff --git a/codex-rs/core/src/tools/handlers/shell/container_exec.rs b/codex-rs/core/src/tools/handlers/shell/container_exec.rs new file mode 100644 index 000000000000..70bf56fb4d3d --- /dev/null +++ b/codex-rs/core/src/tools/handlers/shell/container_exec.rs @@ -0,0 +1,101 @@ +use codex_protocol::models::ShellToolCallParams; +use codex_shell_command::is_safe_command::is_known_safe_command; +use codex_tools::ToolName; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments_with_base_path; +use crate::tools::handlers::resolve_workdir_base_path; +use crate::tools::registry::PostToolUsePayload; +use crate::tools::registry::PreToolUsePayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::tools::runtimes::shell::ShellRuntimeBackend; + +use super::RunExecLikeArgs; +use super::run_exec_like; +use super::shell_function_post_tool_use_payload; +use super::shell_function_pre_tool_use_payload; +use super::shell_handler::ShellHandler; + +pub struct ContainerExecHandler; + +impl ToolHandler for ContainerExecHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("container.exec") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + let ToolPayload::Function { arguments } = &invocation.payload else { + return true; + }; + + serde_json::from_str::(arguments) + .map(|params| !is_known_safe_command(¶ms.command)) + .unwrap_or(true) + } + + fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { + shell_function_pre_tool_use_payload(invocation) + } + + fn post_tool_use_payload( + &self, + invocation: &ToolInvocation, + result: &Self::Output, + ) -> Option { + shell_function_post_tool_use_payload(invocation, result) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "unsupported payload for container.exec handler".to_string(), + )); + } + }; + + let cwd = resolve_workdir_base_path(&arguments, &turn.cwd)?; + let params: ShellToolCallParams = parse_arguments_with_base_path(&arguments, &cwd)?; + let prefix_rule = params.prefix_rule.clone(); + let exec_params = + ShellHandler::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); + run_exec_like(RunExecLikeArgs { + tool_name: "container.exec".to_string(), + exec_params, + hook_command: codex_shell_command::parse_command::shlex_join(¶ms.command), + additional_permissions: params.additional_permissions.clone(), + prefix_rule, + session, + turn, + tracker, + call_id, + freeform: false, + shell_runtime_backend: ShellRuntimeBackend::Generic, + }) + .await + } +} diff --git a/codex-rs/core/src/tools/handlers/shell/local_shell.rs b/codex-rs/core/src/tools/handlers/shell/local_shell.rs new file mode 100644 index 000000000000..bdb70e936842 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/shell/local_shell.rs @@ -0,0 +1,102 @@ +use codex_shell_command::is_safe_command::is_known_safe_command; +use codex_tools::ToolName; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::hook_names::HookToolName; +use crate::tools::registry::PostToolUsePayload; +use crate::tools::registry::PreToolUsePayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::tools::runtimes::shell::ShellRuntimeBackend; + +use super::RunExecLikeArgs; +use super::local_shell_payload_command; +use super::run_exec_like; +use super::shell_handler::ShellHandler; + +pub struct LocalShellHandler; + +impl ToolHandler for LocalShellHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("local_shell") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::LocalShell { .. }) + } + + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + let ToolPayload::LocalShell { params } = &invocation.payload else { + return true; + }; + + !is_known_safe_command(¶ms.command) + } + + fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { + local_shell_payload_command(&invocation.payload).map(|command| PreToolUsePayload { + tool_name: HookToolName::bash(), + tool_input: serde_json::json!({ "command": command }), + }) + } + + fn post_tool_use_payload( + &self, + invocation: &ToolInvocation, + result: &Self::Output, + ) -> Option { + let tool_response = + result.post_tool_use_response(&invocation.call_id, &invocation.payload)?; + let command = local_shell_payload_command(&invocation.payload)?; + Some(PostToolUsePayload { + tool_name: HookToolName::bash(), + tool_use_id: invocation.call_id.clone(), + tool_input: serde_json::json!({ "command": command }), + tool_response, + }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + call_id, + payload, + .. + } = invocation; + + let ToolPayload::LocalShell { params } = payload else { + return Err(FunctionCallError::RespondToModel( + "unsupported payload for local_shell handler".to_string(), + )); + }; + + let exec_params = + ShellHandler::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); + run_exec_like(RunExecLikeArgs { + tool_name: "local_shell".to_string(), + exec_params, + hook_command: codex_shell_command::parse_command::shlex_join(¶ms.command), + additional_permissions: None, + prefix_rule: None, + session, + turn, + tracker, + call_id, + freeform: false, + shell_runtime_backend: ShellRuntimeBackend::Generic, + }) + .await + } +} diff --git a/codex-rs/core/src/tools/handlers/shell/shell_command.rs b/codex-rs/core/src/tools/handlers/shell/shell_command.rs new file mode 100644 index 000000000000..69f965b51e09 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/shell/shell_command.rs @@ -0,0 +1,215 @@ +use codex_protocol::ThreadId; +use codex_protocol::models::ShellCommandToolCallParams; +use codex_shell_command::is_safe_command::is_known_safe_command; +use codex_tools::ShellCommandBackendConfig; +use codex_tools::ToolName; + +use crate::exec::ExecCapturePolicy; +use crate::exec::ExecParams; +use crate::exec_env::create_env; +use crate::function_tool::FunctionCallError; +use crate::maybe_emit_implicit_skill_invocation; +use crate::session::turn_context::TurnContext; +use crate::shell::Shell; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments_with_base_path; +use crate::tools::handlers::resolve_workdir_base_path; +use crate::tools::hook_names::HookToolName; +use crate::tools::registry::PostToolUsePayload; +use crate::tools::registry::PreToolUsePayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::tools::runtimes::shell::ShellRuntimeBackend; + +use super::RunExecLikeArgs; +use super::run_exec_like; +use super::shell_command_payload_command; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ShellCommandBackend { + Classic, + ZshFork, +} + +pub struct ShellCommandHandler { + backend: ShellCommandBackend, +} + +impl ShellCommandHandler { + fn shell_runtime_backend(&self) -> ShellRuntimeBackend { + match self.backend { + ShellCommandBackend::Classic => ShellRuntimeBackend::ShellCommandClassic, + ShellCommandBackend::ZshFork => ShellRuntimeBackend::ShellCommandZshFork, + } + } + + pub(super) fn resolve_use_login_shell( + login: Option, + allow_login_shell: bool, + ) -> Result { + if !allow_login_shell && login == Some(true) { + return Err(FunctionCallError::RespondToModel( + "login shell is disabled by config; omit `login` or set it to false.".to_string(), + )); + } + + Ok(login.unwrap_or(allow_login_shell)) + } + + pub(super) fn base_command(shell: &Shell, command: &str, use_login_shell: bool) -> Vec { + shell.derive_exec_args(command, use_login_shell) + } + + pub(super) fn to_exec_params( + params: &ShellCommandToolCallParams, + session: &crate::session::session::Session, + turn_context: &TurnContext, + thread_id: ThreadId, + allow_login_shell: bool, + ) -> Result { + let shell = session.user_shell(); + let use_login_shell = Self::resolve_use_login_shell(params.login, allow_login_shell)?; + let command = Self::base_command(shell.as_ref(), ¶ms.command, use_login_shell); + + Ok(ExecParams { + command, + cwd: turn_context.resolve_path(params.workdir.clone()), + expiration: params.timeout_ms.into(), + capture_policy: ExecCapturePolicy::ShellTool, + env: create_env(&turn_context.shell_environment_policy, Some(thread_id)), + network: turn_context.network.clone(), + sandbox_permissions: params.sandbox_permissions.unwrap_or_default(), + windows_sandbox_level: turn_context.windows_sandbox_level, + windows_sandbox_private_desktop: turn_context + .config + .permissions + .windows_sandbox_private_desktop, + justification: params.justification.clone(), + arg0: None, + }) + } +} + +impl From for ShellCommandHandler { + fn from(config: ShellCommandBackendConfig) -> Self { + let backend = match config { + ShellCommandBackendConfig::Classic => ShellCommandBackend::Classic, + ShellCommandBackendConfig::ZshFork => ShellCommandBackend::ZshFork, + }; + Self { backend } + } +} + +impl ToolHandler for ShellCommandHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("shell_command") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + let ToolPayload::Function { arguments } = &invocation.payload else { + return true; + }; + + serde_json::from_str::(arguments) + .map(|params| { + let use_login_shell = match Self::resolve_use_login_shell( + params.login, + invocation.turn.tools_config.allow_login_shell, + ) { + Ok(use_login_shell) => use_login_shell, + Err(_) => return true, + }; + let shell = invocation.session.user_shell(); + let command = Self::base_command(shell.as_ref(), ¶ms.command, use_login_shell); + !is_known_safe_command(&command) + }) + .unwrap_or(true) + } + + fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { + shell_command_payload_command(&invocation.payload).map(|command| PreToolUsePayload { + tool_name: HookToolName::bash(), + tool_input: serde_json::json!({ "command": command }), + }) + } + + fn post_tool_use_payload( + &self, + invocation: &ToolInvocation, + result: &Self::Output, + ) -> Option { + let tool_response = + result.post_tool_use_response(&invocation.call_id, &invocation.payload)?; + let command = shell_command_payload_command(&invocation.payload)?; + Some(PostToolUsePayload { + tool_name: HookToolName::bash(), + tool_use_id: invocation.call_id.clone(), + tool_input: serde_json::json!({ "command": command }), + tool_response, + }) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + call_id, + payload, + .. + } = invocation; + + let ToolPayload::Function { arguments } = payload else { + return Err(FunctionCallError::RespondToModel(format!( + "unsupported payload for shell_command handler: {}", + self.tool_name().display() + ))); + }; + + let cwd = resolve_workdir_base_path(&arguments, &turn.cwd)?; + let params: ShellCommandToolCallParams = parse_arguments_with_base_path(&arguments, &cwd)?; + let workdir = turn.resolve_path(params.workdir.clone()); + maybe_emit_implicit_skill_invocation( + session.as_ref(), + turn.as_ref(), + ¶ms.command, + &workdir, + ) + .await; + let prefix_rule = params.prefix_rule.clone(); + let exec_params = Self::to_exec_params( + ¶ms, + session.as_ref(), + turn.as_ref(), + session.conversation_id, + turn.tools_config.allow_login_shell, + )?; + run_exec_like(RunExecLikeArgs { + tool_name: self.tool_name().display(), + exec_params, + hook_command: params.command, + additional_permissions: params.additional_permissions.clone(), + prefix_rule, + session, + turn, + tracker, + call_id, + freeform: true, + shell_runtime_backend: self.shell_runtime_backend(), + }) + .await + } +} diff --git a/codex-rs/core/src/tools/handlers/shell/shell_handler.rs b/codex-rs/core/src/tools/handlers/shell/shell_handler.rs new file mode 100644 index 000000000000..30220d3db0da --- /dev/null +++ b/codex-rs/core/src/tools/handlers/shell/shell_handler.rs @@ -0,0 +1,130 @@ +use codex_protocol::ThreadId; +use codex_protocol::models::ShellToolCallParams; +use codex_shell_command::is_safe_command::is_known_safe_command; +use codex_tools::ToolName; + +use crate::exec::ExecCapturePolicy; +use crate::exec::ExecParams; +use crate::exec_env::create_env; +use crate::function_tool::FunctionCallError; +use crate::session::turn_context::TurnContext; +use crate::tools::context::FunctionToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments_with_base_path; +use crate::tools::handlers::resolve_workdir_base_path; +use crate::tools::registry::PostToolUsePayload; +use crate::tools::registry::PreToolUsePayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::tools::runtimes::shell::ShellRuntimeBackend; + +use super::RunExecLikeArgs; +use super::run_exec_like; +use super::shell_function_post_tool_use_payload; +use super::shell_function_pre_tool_use_payload; + +pub struct ShellHandler; + +impl ShellHandler { + pub(super) fn to_exec_params( + params: &ShellToolCallParams, + turn_context: &TurnContext, + thread_id: ThreadId, + ) -> ExecParams { + ExecParams { + command: params.command.clone(), + cwd: turn_context.resolve_path(params.workdir.clone()), + expiration: params.timeout_ms.into(), + capture_policy: ExecCapturePolicy::ShellTool, + env: create_env(&turn_context.shell_environment_policy, Some(thread_id)), + network: turn_context.network.clone(), + sandbox_permissions: params.sandbox_permissions.unwrap_or_default(), + windows_sandbox_level: turn_context.windows_sandbox_level, + windows_sandbox_private_desktop: turn_context + .config + .permissions + .windows_sandbox_private_desktop, + justification: params.justification.clone(), + arg0: None, + } + } +} + +impl ToolHandler for ShellHandler { + type Output = FunctionToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("shell") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + let ToolPayload::Function { arguments } = &invocation.payload else { + return true; + }; + + serde_json::from_str::(arguments) + .map(|params| !is_known_safe_command(¶ms.command)) + .unwrap_or(true) + } + + fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { + shell_function_pre_tool_use_payload(invocation) + } + + fn post_tool_use_payload( + &self, + invocation: &ToolInvocation, + result: &Self::Output, + ) -> Option { + shell_function_post_tool_use_payload(invocation, result) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "unsupported payload for shell handler".to_string(), + )); + } + }; + + let cwd = resolve_workdir_base_path(&arguments, &turn.cwd)?; + let params: ShellToolCallParams = parse_arguments_with_base_path(&arguments, &cwd)?; + let prefix_rule = params.prefix_rule.clone(); + let exec_params = + ShellHandler::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); + run_exec_like(RunExecLikeArgs { + tool_name: "shell".to_string(), + exec_params, + hook_command: codex_shell_command::parse_command::shlex_join(¶ms.command), + additional_permissions: params.additional_permissions.clone(), + prefix_rule, + session, + turn, + tracker, + call_id, + freeform: false, + shell_runtime_backend: ShellRuntimeBackend::Generic, + }) + .await + } +} diff --git a/codex-rs/core/src/tools/handlers/shell_tests.rs b/codex-rs/core/src/tools/handlers/shell_tests.rs index 8a32e5404b1d..a7e6dae35c7c 100644 --- a/codex-rs/core/src/tools/handlers/shell_tests.rs +++ b/codex-rs/core/src/tools/handlers/shell_tests.rs @@ -247,9 +247,7 @@ async fn shell_command_pre_tool_use_payload_uses_raw_command() { arguments: json!({ "command": "printf shell command" }).to_string(), }; let (session, turn) = make_session_and_context().await; - let handler = ShellCommandHandler { - backend: super::ShellCommandBackend::Classic, - }; + let handler = ShellCommandHandler::from(codex_tools::ShellCommandBackendConfig::Classic); assert_eq!( handler.pre_tool_use_payload(&ToolInvocation { @@ -279,9 +277,7 @@ async fn build_post_tool_use_payload_uses_tool_output_wire_value() { success: Some(true), post_tool_use_response: Some(json!("shell output")), }; - let handler = ShellCommandHandler { - backend: super::ShellCommandBackend::Classic, - }; + let handler = ShellCommandHandler::from(codex_tools::ShellCommandBackendConfig::Classic); let (session, turn) = make_session_and_context().await; let invocation = ToolInvocation { session: session.into(), diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index c257240a4d74..80e85ccd474a 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -1,5 +1,3 @@ -use crate::function_tool::FunctionCallError; -use crate::maybe_emit_implicit_skill_invocation; use crate::sandboxing::SandboxPermissions; use crate::shell::Shell; use crate::shell::get_shell_by_model_provided_path; @@ -7,42 +5,24 @@ use crate::tools::context::ExecCommandToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; -use crate::tools::handlers::apply_granted_turn_permissions; -use crate::tools::handlers::apply_patch::intercept_apply_patch; -use crate::tools::handlers::implicit_granted_permissions; -use crate::tools::handlers::normalize_and_validate_additional_permissions; -use crate::tools::handlers::parse_arguments; -use crate::tools::handlers::parse_arguments_with_base_path; -use crate::tools::handlers::resolve_tool_environment; use crate::tools::hook_names::HookToolName; use crate::tools::registry::PostToolUsePayload; -use crate::tools::registry::PreToolUsePayload; -use crate::tools::registry::ToolHandler; -use crate::tools::registry::ToolKind; -use crate::unified_exec::ExecCommandRequest; -use crate::unified_exec::UnifiedExecContext; -use crate::unified_exec::UnifiedExecError; -use crate::unified_exec::UnifiedExecProcessManager; -use crate::unified_exec::WriteStdinRequest; -use crate::unified_exec::generate_chunk_id; use crate::unified_exec::resolve_max_tokens; -use codex_features::Feature; -use codex_otel::SessionTelemetry; -use codex_otel::TOOL_CALL_UNIFIED_EXEC_METRIC; use codex_protocol::models::AdditionalPermissionProfile; -use codex_protocol::protocol::EventMsg; -use codex_protocol::protocol::TerminalInteractionEvent; -use codex_shell_command::is_safe_command::is_known_safe_command; -use codex_tools::ToolName; use codex_tools::UnifiedExecShellMode; use codex_utils_output_truncation::TruncationPolicy; -use codex_utils_output_truncation::approx_token_count; use serde::Deserialize; use std::path::PathBuf; use std::sync::Arc; -pub struct ExecCommandHandler; -pub struct WriteStdinHandler; +#[cfg(test)] +use crate::tools::handlers::parse_arguments; + +mod exec_command; +mod write_stdin; + +pub use exec_command::ExecCommandHandler; +pub use write_stdin::WriteStdinHandler; #[derive(Debug, Deserialize)] pub(crate) struct ExecCommandArgs { @@ -79,18 +59,6 @@ struct ExecCommandEnvironmentArgs { workdir: Option, } -#[derive(Debug, Deserialize)] -struct WriteStdinArgs { - // The model is trained on `session_id`. - session_id: i32, - #[serde(default)] - chars: String, - #[serde(default = "default_write_stdin_yield_time_ms")] - yield_time_ms: u64, - #[serde(default)] - max_output_tokens: Option, -} - fn default_exec_yield_time_ms() -> u64 { 10_000 } @@ -110,343 +78,6 @@ fn effective_max_output_tokens( resolve_max_tokens(max_output_tokens).min(truncation_policy.token_budget()) } -impl ToolHandler for ExecCommandHandler { - type Output = ExecCommandToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("exec_command") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { - let ToolPayload::Function { arguments } = &invocation.payload else { - tracing::error!( - "This should never happen, invocation payload is wrong: {:?}", - invocation.payload - ); - return true; - }; - - let Ok(params) = parse_arguments::(arguments) else { - return true; - }; - let command = match get_command( - ¶ms, - invocation.session.user_shell(), - &invocation.turn.tools_config.unified_exec_shell_mode, - invocation.turn.tools_config.allow_login_shell, - ) { - Ok(command) => command, - Err(_) => return true, - }; - !is_known_safe_command(&command) - } - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - let ToolPayload::Function { arguments } = &invocation.payload else { - return None; - }; - - parse_arguments::(arguments) - .ok() - .map(|args| PreToolUsePayload { - tool_name: HookToolName::bash(), - tool_input: serde_json::json!({ "command": args.cmd }), - }) - } - - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &Self::Output, - ) -> Option { - post_unified_exec_tool_use_payload(invocation, result) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "exec_command handler received unsupported payload".to_string(), - )); - } - }; - - let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager; - let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone()); - let environment_args: ExecCommandEnvironmentArgs = parse_arguments(&arguments)?; - let Some(turn_environment) = - resolve_tool_environment(turn.as_ref(), environment_args.environment_id.as_deref())? - else { - return Err(FunctionCallError::RespondToModel( - "unified exec is unavailable in this session".to_string(), - )); - }; - let cwd = environment_args - .workdir - .as_deref() - .filter(|workdir| !workdir.is_empty()) - .map_or_else( - || turn_environment.cwd.clone(), - |workdir| turn_environment.cwd.join(workdir), - ); - let environment = Arc::clone(&turn_environment.environment); - let fs = environment.get_filesystem(); - let args: ExecCommandArgs = parse_arguments_with_base_path(&arguments, &cwd)?; - let hook_command = args.cmd.clone(); - maybe_emit_implicit_skill_invocation( - session.as_ref(), - context.turn.as_ref(), - &hook_command, - &cwd, - ) - .await; - let process_id = manager.allocate_process_id().await; - let command = get_command( - &args, - session.user_shell(), - &turn.tools_config.unified_exec_shell_mode, - turn.tools_config.allow_login_shell, - ) - .map_err(FunctionCallError::RespondToModel)?; - let command_for_display = codex_shell_command::parse_command::shlex_join(&command); - - let ExecCommandArgs { - tty, - yield_time_ms, - max_output_tokens, - sandbox_permissions, - additional_permissions, - justification, - prefix_rule, - .. - } = args; - let max_output_tokens = - effective_max_output_tokens(max_output_tokens, turn.truncation_policy); - - let exec_permission_approvals_enabled = - session.features().enabled(Feature::ExecPermissionApprovals); - let requested_additional_permissions = additional_permissions.clone(); - let effective_additional_permissions = apply_granted_turn_permissions( - context.session.as_ref(), - cwd.as_path(), - sandbox_permissions, - additional_permissions, - ) - .await; - let additional_permissions_allowed = exec_permission_approvals_enabled - || (session.features().enabled(Feature::RequestPermissionsTool) - && effective_additional_permissions.permissions_preapproved); - - // Sticky turn permissions have already been approved, so they should - // continue through the normal exec approval flow for the command. - if effective_additional_permissions - .sandbox_permissions - .requests_sandbox_override() - && !effective_additional_permissions.permissions_preapproved - && !matches!( - context.turn.approval_policy.value(), - codex_protocol::protocol::AskForApproval::OnRequest - ) - { - let approval_policy = context.turn.approval_policy.value(); - manager.release_process_id(process_id).await; - return Err(FunctionCallError::RespondToModel(format!( - "approval policy is {approval_policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {approval_policy:?}" - ))); - } - - let normalized_additional_permissions = match implicit_granted_permissions( - sandbox_permissions, - requested_additional_permissions.as_ref(), - &effective_additional_permissions, - ) - .map_or_else( - || { - normalize_and_validate_additional_permissions( - additional_permissions_allowed, - context.turn.approval_policy.value(), - effective_additional_permissions.sandbox_permissions, - effective_additional_permissions.additional_permissions, - effective_additional_permissions.permissions_preapproved, - &cwd, - ) - }, - |permissions| Ok(Some(permissions)), - ) { - Ok(normalized) => normalized, - Err(err) => { - manager.release_process_id(process_id).await; - return Err(FunctionCallError::RespondToModel(err)); - } - }; - - if let Some(output) = intercept_apply_patch( - &command, - &cwd, - fs.as_ref(), - context.session.clone(), - context.turn.clone(), - Some(&tracker), - &context.call_id, - "exec_command", - ) - .await? - { - manager.release_process_id(process_id).await; - return Ok(ExecCommandToolOutput { - event_call_id: String::new(), - chunk_id: String::new(), - wall_time: std::time::Duration::ZERO, - raw_output: output.into_text().into_bytes(), - max_output_tokens: Some(max_output_tokens), - process_id: None, - exit_code: None, - original_token_count: None, - hook_command: None, - }); - } - - emit_unified_exec_tty_metric(&turn.session_telemetry, tty); - match manager - .exec_command( - ExecCommandRequest { - command, - hook_command: hook_command.clone(), - process_id, - yield_time_ms, - max_output_tokens: Some(max_output_tokens), - cwd, - environment, - network: context.turn.network.clone(), - tty, - sandbox_permissions: effective_additional_permissions.sandbox_permissions, - additional_permissions: normalized_additional_permissions, - additional_permissions_preapproved: effective_additional_permissions - .permissions_preapproved, - justification, - prefix_rule, - }, - &context, - ) - .await - { - Ok(response) => Ok(response), - Err(UnifiedExecError::SandboxDenied { output, .. }) => { - let output_text = output.aggregated_output.text; - let original_token_count = approx_token_count(&output_text); - Ok(ExecCommandToolOutput { - event_call_id: context.call_id.clone(), - chunk_id: generate_chunk_id(), - wall_time: output.duration, - raw_output: output_text.into_bytes(), - max_output_tokens: Some(max_output_tokens), - // Sandbox denial is terminal, so there is no live - // process for write_stdin to resume. - process_id: None, - exit_code: Some(output.exit_code), - original_token_count: Some(original_token_count), - hook_command: Some(hook_command), - }) - } - Err(err) => Err(FunctionCallError::RespondToModel(format!( - "exec_command failed for `{command_for_display}`: {err:?}" - ))), - } - } -} - -impl ToolHandler for WriteStdinHandler { - type Output = ExecCommandToolOutput; - - fn tool_name(&self) -> ToolName { - ToolName::plain("write_stdin") - } - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - fn matches_kind(&self, payload: &ToolPayload) -> bool { - matches!(payload, ToolPayload::Function { .. }) - } - - async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { - true - } - - fn post_tool_use_payload( - &self, - invocation: &ToolInvocation, - result: &Self::Output, - ) -> Option { - post_unified_exec_tool_use_payload(invocation, result) - } - - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "write_stdin handler received unsupported payload".to_string(), - )); - } - }; - - let args: WriteStdinArgs = parse_arguments(&arguments)?; - let max_output_tokens = - effective_max_output_tokens(args.max_output_tokens, turn.truncation_policy); - let response = session - .services - .unified_exec_manager - .write_stdin(WriteStdinRequest { - process_id: args.session_id, - input: &args.chars, - yield_time_ms: args.yield_time_ms, - max_output_tokens: Some(max_output_tokens), - }) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!("write_stdin failed: {err}")) - })?; - - let interaction = TerminalInteractionEvent { - call_id: response.event_call_id.clone(), - process_id: args.session_id.to_string(), - stdin: args.chars.clone(), - }; - session - .send_event(turn.as_ref(), EventMsg::TerminalInteraction(interaction)) - .await; - - Ok(response) - } -} - fn post_unified_exec_tool_use_payload( invocation: &ToolInvocation, result: &ExecCommandToolOutput, @@ -470,14 +101,6 @@ fn post_unified_exec_tool_use_payload( }) } -fn emit_unified_exec_tty_metric(session_telemetry: &SessionTelemetry, tty: bool) { - session_telemetry.counter( - TOOL_CALL_UNIFIED_EXEC_METRIC, - /*inc*/ 1, - &[("tty", if tty { "true" } else { "false" })], - ); -} - pub(crate) fn get_command( args: &ExecCommandArgs, session_shell: Arc, diff --git a/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs b/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs new file mode 100644 index 000000000000..75ae3fea29e6 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/unified_exec/exec_command.rs @@ -0,0 +1,309 @@ +use std::sync::Arc; + +use crate::function_tool::FunctionCallError; +use crate::maybe_emit_implicit_skill_invocation; +use crate::tools::context::ExecCommandToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::apply_granted_turn_permissions; +use crate::tools::handlers::apply_patch::intercept_apply_patch; +use crate::tools::handlers::implicit_granted_permissions; +use crate::tools::handlers::normalize_and_validate_additional_permissions; +use crate::tools::handlers::parse_arguments; +use crate::tools::handlers::parse_arguments_with_base_path; +use crate::tools::handlers::resolve_tool_environment; +use crate::tools::hook_names::HookToolName; +use crate::tools::registry::PostToolUsePayload; +use crate::tools::registry::PreToolUsePayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::unified_exec::ExecCommandRequest; +use crate::unified_exec::UnifiedExecContext; +use crate::unified_exec::UnifiedExecError; +use crate::unified_exec::UnifiedExecProcessManager; +use crate::unified_exec::generate_chunk_id; +use codex_features::Feature; +use codex_otel::SessionTelemetry; +use codex_otel::TOOL_CALL_UNIFIED_EXEC_METRIC; +use codex_shell_command::is_safe_command::is_known_safe_command; +use codex_tools::ToolName; +use codex_utils_output_truncation::approx_token_count; + +use super::ExecCommandArgs; +use super::ExecCommandEnvironmentArgs; +use super::effective_max_output_tokens; +use super::get_command; +use super::post_unified_exec_tool_use_payload; + +pub struct ExecCommandHandler; + +impl ToolHandler for ExecCommandHandler { + type Output = ExecCommandToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("exec_command") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + let ToolPayload::Function { arguments } = &invocation.payload else { + tracing::error!( + "This should never happen, invocation payload is wrong: {:?}", + invocation.payload + ); + return true; + }; + + let Ok(params) = parse_arguments::(arguments) else { + return true; + }; + let command = match get_command( + ¶ms, + invocation.session.user_shell(), + &invocation.turn.tools_config.unified_exec_shell_mode, + invocation.turn.tools_config.allow_login_shell, + ) { + Ok(command) => command, + Err(_) => return true, + }; + !is_known_safe_command(&command) + } + + fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { + let ToolPayload::Function { arguments } = &invocation.payload else { + return None; + }; + + parse_arguments::(arguments) + .ok() + .map(|args| PreToolUsePayload { + tool_name: HookToolName::bash(), + tool_input: serde_json::json!({ "command": args.cmd }), + }) + } + + fn post_tool_use_payload( + &self, + invocation: &ToolInvocation, + result: &Self::Output, + ) -> Option { + post_unified_exec_tool_use_payload(invocation, result) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "exec_command handler received unsupported payload".to_string(), + )); + } + }; + + let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager; + let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone()); + let environment_args: ExecCommandEnvironmentArgs = parse_arguments(&arguments)?; + let Some(turn_environment) = + resolve_tool_environment(turn.as_ref(), environment_args.environment_id.as_deref())? + else { + return Err(FunctionCallError::RespondToModel( + "unified exec is unavailable in this session".to_string(), + )); + }; + let cwd = environment_args + .workdir + .as_deref() + .filter(|workdir| !workdir.is_empty()) + .map_or_else( + || turn_environment.cwd.clone(), + |workdir| turn_environment.cwd.join(workdir), + ); + let environment = Arc::clone(&turn_environment.environment); + let fs = environment.get_filesystem(); + let args: ExecCommandArgs = parse_arguments_with_base_path(&arguments, &cwd)?; + let hook_command = args.cmd.clone(); + maybe_emit_implicit_skill_invocation( + session.as_ref(), + context.turn.as_ref(), + &hook_command, + &cwd, + ) + .await; + let process_id = manager.allocate_process_id().await; + let command = get_command( + &args, + session.user_shell(), + &turn.tools_config.unified_exec_shell_mode, + turn.tools_config.allow_login_shell, + ) + .map_err(FunctionCallError::RespondToModel)?; + let command_for_display = codex_shell_command::parse_command::shlex_join(&command); + + let ExecCommandArgs { + tty, + yield_time_ms, + max_output_tokens, + sandbox_permissions, + additional_permissions, + justification, + prefix_rule, + .. + } = args; + let max_output_tokens = + effective_max_output_tokens(max_output_tokens, turn.truncation_policy); + + let exec_permission_approvals_enabled = + session.features().enabled(Feature::ExecPermissionApprovals); + let requested_additional_permissions = additional_permissions.clone(); + let effective_additional_permissions = apply_granted_turn_permissions( + context.session.as_ref(), + cwd.as_path(), + sandbox_permissions, + additional_permissions, + ) + .await; + let additional_permissions_allowed = exec_permission_approvals_enabled + || (session.features().enabled(Feature::RequestPermissionsTool) + && effective_additional_permissions.permissions_preapproved); + + // Sticky turn permissions have already been approved, so they should + // continue through the normal exec approval flow for the command. + if effective_additional_permissions + .sandbox_permissions + .requests_sandbox_override() + && !effective_additional_permissions.permissions_preapproved + && !matches!( + context.turn.approval_policy.value(), + codex_protocol::protocol::AskForApproval::OnRequest + ) + { + let approval_policy = context.turn.approval_policy.value(); + manager.release_process_id(process_id).await; + return Err(FunctionCallError::RespondToModel(format!( + "approval policy is {approval_policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {approval_policy:?}" + ))); + } + + let normalized_additional_permissions = match implicit_granted_permissions( + sandbox_permissions, + requested_additional_permissions.as_ref(), + &effective_additional_permissions, + ) + .map_or_else( + || { + normalize_and_validate_additional_permissions( + additional_permissions_allowed, + context.turn.approval_policy.value(), + effective_additional_permissions.sandbox_permissions, + effective_additional_permissions.additional_permissions, + effective_additional_permissions.permissions_preapproved, + &cwd, + ) + }, + |permissions| Ok(Some(permissions)), + ) { + Ok(normalized) => normalized, + Err(err) => { + manager.release_process_id(process_id).await; + return Err(FunctionCallError::RespondToModel(err)); + } + }; + + if let Some(output) = intercept_apply_patch( + &command, + &cwd, + fs.as_ref(), + context.session.clone(), + context.turn.clone(), + Some(&tracker), + &context.call_id, + "exec_command", + ) + .await? + { + manager.release_process_id(process_id).await; + return Ok(ExecCommandToolOutput { + event_call_id: String::new(), + chunk_id: String::new(), + wall_time: std::time::Duration::ZERO, + raw_output: output.into_text().into_bytes(), + max_output_tokens: Some(max_output_tokens), + process_id: None, + exit_code: None, + original_token_count: None, + hook_command: None, + }); + } + + emit_unified_exec_tty_metric(&turn.session_telemetry, tty); + match manager + .exec_command( + ExecCommandRequest { + command, + hook_command: hook_command.clone(), + process_id, + yield_time_ms, + max_output_tokens: Some(max_output_tokens), + cwd, + environment, + network: context.turn.network.clone(), + tty, + sandbox_permissions: effective_additional_permissions.sandbox_permissions, + additional_permissions: normalized_additional_permissions, + additional_permissions_preapproved: effective_additional_permissions + .permissions_preapproved, + justification, + prefix_rule, + }, + &context, + ) + .await + { + Ok(response) => Ok(response), + Err(UnifiedExecError::SandboxDenied { output, .. }) => { + let output_text = output.aggregated_output.text; + let original_token_count = approx_token_count(&output_text); + Ok(ExecCommandToolOutput { + event_call_id: context.call_id.clone(), + chunk_id: generate_chunk_id(), + wall_time: output.duration, + raw_output: output_text.into_bytes(), + max_output_tokens: Some(max_output_tokens), + // Sandbox denial is terminal, so there is no live + // process for write_stdin to resume. + process_id: None, + exit_code: Some(output.exit_code), + original_token_count: Some(original_token_count), + hook_command: Some(hook_command), + }) + } + Err(err) => Err(FunctionCallError::RespondToModel(format!( + "exec_command failed for `{command_for_display}`: {err:?}" + ))), + } + } +} + +fn emit_unified_exec_tty_metric(session_telemetry: &SessionTelemetry, tty: bool) { + session_telemetry.counter( + TOOL_CALL_UNIFIED_EXEC_METRIC, + /*inc*/ 1, + &[("tty", if tty { "true" } else { "false" })], + ); +} diff --git a/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs b/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs new file mode 100644 index 000000000000..1e9c68f227ff --- /dev/null +++ b/codex-rs/core/src/tools/handlers/unified_exec/write_stdin.rs @@ -0,0 +1,104 @@ +use crate::function_tool::FunctionCallError; +use crate::tools::context::ExecCommandToolOutput; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; +use crate::tools::registry::PostToolUsePayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::unified_exec::WriteStdinRequest; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::TerminalInteractionEvent; +use codex_tools::ToolName; +use serde::Deserialize; + +use super::effective_max_output_tokens; +use super::post_unified_exec_tool_use_payload; + +#[derive(Debug, Deserialize)] +struct WriteStdinArgs { + // The model is trained on `session_id`. + session_id: i32, + #[serde(default)] + chars: String, + #[serde(default = "super::default_write_stdin_yield_time_ms")] + yield_time_ms: u64, + #[serde(default)] + max_output_tokens: Option, +} + +pub struct WriteStdinHandler; + +impl ToolHandler for WriteStdinHandler { + type Output = ExecCommandToolOutput; + + fn tool_name(&self) -> ToolName { + ToolName::plain("write_stdin") + } + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!(payload, ToolPayload::Function { .. }) + } + + async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + true + } + + fn post_tool_use_payload( + &self, + invocation: &ToolInvocation, + result: &Self::Output, + ) -> Option { + post_unified_exec_tool_use_payload(invocation, result) + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "write_stdin handler received unsupported payload".to_string(), + )); + } + }; + + let args: WriteStdinArgs = parse_arguments(&arguments)?; + let max_output_tokens = + effective_max_output_tokens(args.max_output_tokens, turn.truncation_policy); + let response = session + .services + .unified_exec_manager + .write_stdin(WriteStdinRequest { + process_id: args.session_id, + input: &args.chars, + yield_time_ms: args.yield_time_ms, + max_output_tokens: Some(max_output_tokens), + }) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("write_stdin failed: {err}")) + })?; + + let interaction = TerminalInteractionEvent { + call_id: response.event_call_id.clone(), + process_id: args.session_id.to_string(), + stdin: args.chars.clone(), + }; + session + .send_event(turn.as_ref(), EventMsg::TerminalInteraction(interaction)) + .await; + + Ok(response) + } +}