diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index 15ada81bb96f..feb4b02d71b7 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -370,7 +370,7 @@ pub struct ExecCommandToolOutput { pub process_id: Option, pub exit_code: Option, pub original_token_count: Option, - pub session_command: Option>, + pub hook_command: Option, } impl ToolOutput for ExecCommandToolOutput { @@ -394,7 +394,7 @@ impl ToolOutput for ExecCommandToolOutput { } fn post_tool_use_response(&self, _call_id: &str, _payload: &ToolPayload) -> Option { - if self.process_id.is_some() || self.session_command.is_none() { + if self.process_id.is_some() || self.hook_command.is_none() { return None; } diff --git a/codex-rs/core/src/tools/context_tests.rs b/codex-rs/core/src/tools/context_tests.rs index 4b903347aba4..60ad4561519f 100644 --- a/codex-rs/core/src/tools/context_tests.rs +++ b/codex-rs/core/src/tools/context_tests.rs @@ -390,7 +390,7 @@ fn exec_command_tool_output_formats_truncated_response() { process_id: None, exit_code: Some(0), original_token_count: Some(10), - session_command: None, + hook_command: None, } .to_response_item("call-42", &payload); diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 8eca2caf3f07..0ccc62d9e988 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -324,11 +324,12 @@ impl ToolHandler for ApplyPatchHandler { &self, call_id: &str, payload: &ToolPayload, - result: &dyn ToolOutput, + result: &Self::Output, ) -> Option { let tool_response = result.post_tool_use_response(call_id, payload)?; Some(PostToolUsePayload { tool_name: HookToolName::apply_patch(), + tool_use_id: call_id.to_string(), command: apply_patch_payload_command(payload)?, tool_response, }) diff --git a/codex-rs/core/src/tools/handlers/apply_patch_tests.rs b/codex-rs/core/src/tools/handlers/apply_patch_tests.rs index 979b4d4e96d2..39c313310316 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch_tests.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch_tests.rs @@ -90,6 +90,7 @@ fn post_tool_use_payload_uses_patch_input_and_tool_output() { handler.post_tool_use_payload("call-apply-patch", &payload, &output), Some(PostToolUsePayload { tool_name: HookToolName::apply_patch(), + tool_use_id: "call-apply-patch".to_string(), command: patch.to_string(), tool_response: json!("Success. Updated files."), }) diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 00343d3ceaf3..b9718173b7d9 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -216,11 +216,12 @@ impl ToolHandler for ShellHandler { &self, call_id: &str, payload: &ToolPayload, - result: &dyn ToolOutput, + result: &Self::Output, ) -> Option { let tool_response = result.post_tool_use_response(call_id, payload)?; Some(PostToolUsePayload { tool_name: HookToolName::bash(), + tool_use_id: call_id.to_string(), command: shell_payload_command(payload)?, tool_response, }) @@ -328,11 +329,12 @@ impl ToolHandler for ShellCommandHandler { &self, call_id: &str, payload: &ToolPayload, - result: &dyn ToolOutput, + result: &Self::Output, ) -> Option { let tool_response = result.post_tool_use_response(call_id, payload)?; Some(PostToolUsePayload { tool_name: HookToolName::bash(), + tool_use_id: call_id.to_string(), command: shell_command_payload_command(payload)?, tool_response, }) diff --git a/codex-rs/core/src/tools/handlers/shell_tests.rs b/codex-rs/core/src/tools/handlers/shell_tests.rs index b36dd0225c97..c13d834c061d 100644 --- a/codex-rs/core/src/tools/handlers/shell_tests.rs +++ b/codex-rs/core/src/tools/handlers/shell_tests.rs @@ -284,6 +284,7 @@ fn build_post_tool_use_payload_uses_tool_output_wire_value() { handler.post_tool_use_payload("call-42", &payload, &output), Some(crate::tools::registry::PostToolUsePayload { tool_name: HookToolName::bash(), + tool_use_id: "call-42".to_string(), command: "printf shell command".to_string(), tool_response: json!("shell output"), }) diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index b25f5a69b791..813f4f862806 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -147,21 +147,23 @@ impl ToolHandler for UnifiedExecHandler { &self, call_id: &str, payload: &ToolPayload, - result: &dyn ToolOutput, + result: &Self::Output, ) -> Option { - let ToolPayload::Function { arguments } = payload else { + let ToolPayload::Function { .. } = payload else { return None; }; - let args = parse_arguments::(arguments).ok()?; - if args.tty { - return None; - } - - let tool_response = result.post_tool_use_response(call_id, payload)?; + let command = result.hook_command.clone()?; + let tool_use_id = if result.event_call_id.is_empty() { + call_id.to_string() + } else { + result.event_call_id.clone() + }; + let tool_response = result.post_tool_use_response(&tool_use_id, payload)?; Some(PostToolUsePayload { tool_name: HookToolName::bash(), - command: args.cmd, + tool_use_id, + command, tool_response, }) } @@ -200,11 +202,12 @@ impl ToolHandler for UnifiedExecHandler { "exec_command" => { let cwd = resolve_workdir_base_path(&arguments, &context.turn.cwd)?; let args: ExecCommandArgs = parse_arguments_with_base_path(&arguments, &cwd)?; + let hook_command = args.cmd.clone(); let workdir = context.turn.resolve_path(args.workdir.clone()); maybe_emit_implicit_skill_invocation( session.as_ref(), context.turn.as_ref(), - &args.cmd, + &hook_command, &workdir, ) .await; @@ -313,17 +316,16 @@ impl ToolHandler for UnifiedExecHandler { process_id: None, exit_code: None, original_token_count: None, - session_command: None, + hook_command: None, }); } emit_unified_exec_tty_metric(&turn.session_telemetry, tty); - let session_command = command.clone(); match manager .exec_command( ExecCommandRequest { command, - hook_command: args.cmd, + hook_command: hook_command.clone(), process_id, yield_time_ms, max_output_tokens, @@ -357,7 +359,7 @@ impl ToolHandler for UnifiedExecHandler { process_id: None, exit_code: Some(output.exit_code), original_token_count: Some(original_token_count), - session_command: Some(session_command), + hook_command: Some(hook_command), } } Err(err) => { diff --git a/codex-rs/core/src/tools/handlers/unified_exec_tests.rs b/codex-rs/core/src/tools/handlers/unified_exec_tests.rs index 26f86bd66df9..0a5abe51d74d 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec_tests.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec_tests.rs @@ -252,7 +252,7 @@ fn exec_command_post_tool_use_payload_uses_output_for_noninteractive_one_shot_co arguments: serde_json::json!({ "cmd": "echo three", "tty": false }).to_string(), }; let output = ExecCommandToolOutput { - event_call_id: "event-43".to_string(), + event_call_id: "call-43".to_string(), chunk_id: "chunk-1".to_string(), wall_time: std::time::Duration::from_millis(498), raw_output: b"three".to_vec(), @@ -260,17 +260,14 @@ fn exec_command_post_tool_use_payload_uses_output_for_noninteractive_one_shot_co process_id: None, exit_code: Some(0), original_token_count: None, - session_command: Some(vec![ - "/bin/zsh".to_string(), - "-lc".to_string(), - "echo three".to_string(), - ]), + hook_command: Some("echo three".to_string()), }; assert_eq!( UnifiedExecHandler.post_tool_use_payload("call-43", &payload, &output), Some(crate::tools::registry::PostToolUsePayload { tool_name: HookToolName::bash(), + tool_use_id: "call-43".to_string(), command: "echo three".to_string(), tool_response: serde_json::json!("three"), }) @@ -278,12 +275,12 @@ fn exec_command_post_tool_use_payload_uses_output_for_noninteractive_one_shot_co } #[test] -fn exec_command_post_tool_use_payload_skips_interactive_exec() { +fn exec_command_post_tool_use_payload_uses_output_for_interactive_completion() { let payload = ToolPayload::Function { arguments: serde_json::json!({ "cmd": "echo three", "tty": true }).to_string(), }; let output = ExecCommandToolOutput { - event_call_id: "event-44".to_string(), + event_call_id: "call-44".to_string(), chunk_id: "chunk-1".to_string(), wall_time: std::time::Duration::from_millis(498), raw_output: b"three".to_vec(), @@ -291,16 +288,17 @@ fn exec_command_post_tool_use_payload_skips_interactive_exec() { process_id: None, exit_code: Some(0), original_token_count: None, - session_command: Some(vec![ - "/bin/zsh".to_string(), - "-lc".to_string(), - "echo three".to_string(), - ]), + hook_command: Some("echo three".to_string()), }; assert_eq!( UnifiedExecHandler.post_tool_use_payload("call-44", &payload, &output), - None + Some(crate::tools::registry::PostToolUsePayload { + tool_name: HookToolName::bash(), + tool_use_id: "call-44".to_string(), + command: "echo three".to_string(), + tool_response: serde_json::json!("three"), + }) ); } @@ -318,11 +316,7 @@ fn exec_command_post_tool_use_payload_skips_running_sessions() { process_id: Some(45), exit_code: None, original_token_count: None, - session_command: Some(vec![ - "/bin/zsh".to_string(), - "-lc".to_string(), - "echo three".to_string(), - ]), + hook_command: Some("echo three".to_string()), }; assert_eq!( @@ -330,3 +324,87 @@ fn exec_command_post_tool_use_payload_skips_running_sessions() { None ); } + +#[test] +fn write_stdin_post_tool_use_payload_uses_original_exec_call_id_and_command_on_completion() { + let payload = ToolPayload::Function { + arguments: serde_json::json!({ + "session_id": 45, + "chars": "", + }) + .to_string(), + }; + let output = ExecCommandToolOutput { + event_call_id: "exec-call-45".to_string(), + chunk_id: "chunk-2".to_string(), + wall_time: std::time::Duration::from_millis(498), + raw_output: b"finished\n".to_vec(), + max_output_tokens: None, + process_id: None, + exit_code: Some(0), + original_token_count: None, + hook_command: Some("sleep 1; echo finished".to_string()), + }; + + assert_eq!( + UnifiedExecHandler.post_tool_use_payload("write-stdin-call", &payload, &output), + Some(crate::tools::registry::PostToolUsePayload { + tool_name: HookToolName::bash(), + tool_use_id: "exec-call-45".to_string(), + command: "sleep 1; echo finished".to_string(), + tool_response: serde_json::json!("finished\n"), + }) + ); +} + +#[test] +fn write_stdin_post_tool_use_payload_keeps_parallel_session_metadata_separate() { + let payload = ToolPayload::Function { + arguments: serde_json::json!({ "session_id": 45, "chars": "" }).to_string(), + }; + let output_a = ExecCommandToolOutput { + event_call_id: "exec-call-a".to_string(), + chunk_id: "chunk-a".to_string(), + wall_time: std::time::Duration::from_millis(498), + raw_output: b"alpha\n".to_vec(), + max_output_tokens: None, + process_id: None, + exit_code: Some(0), + original_token_count: None, + hook_command: Some("sleep 2; echo alpha".to_string()), + }; + let output_b = ExecCommandToolOutput { + event_call_id: "exec-call-b".to_string(), + chunk_id: "chunk-b".to_string(), + wall_time: std::time::Duration::from_millis(498), + raw_output: b"beta\n".to_vec(), + max_output_tokens: None, + process_id: None, + exit_code: Some(0), + original_token_count: None, + hook_command: Some("sleep 1; echo beta".to_string()), + }; + + let payloads = [ + UnifiedExecHandler.post_tool_use_payload("write-call-b", &payload, &output_b), + UnifiedExecHandler.post_tool_use_payload("write-call-a", &payload, &output_a), + ]; + + assert_eq!( + payloads, + [ + Some(crate::tools::registry::PostToolUsePayload { + tool_name: HookToolName::bash(), + tool_use_id: "exec-call-b".to_string(), + command: "sleep 1; echo beta".to_string(), + tool_response: serde_json::json!("beta\n"), + }), + Some(crate::tools::registry::PostToolUsePayload { + tool_name: HookToolName::bash(), + tool_use_id: "exec-call-a".to_string(), + command: "sleep 2; echo alpha".to_string(), + tool_response: serde_json::json!("alpha\n"), + }), + ] + ); +} diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 06cf7e5eed0e..384a800b4b11 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -178,6 +178,7 @@ impl ToolCallRuntime { result: Box::new(AbortedToolOutput { message: Self::abort_message(call, secs), }), + post_tool_use_payload: None, } } diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index d70550d0e542..678e620bd865 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -72,7 +72,7 @@ pub trait ToolHandler: Send + Sync { &self, _call_id: &str, _payload: &ToolPayload, - _result: &dyn ToolOutput, + _result: &Self::Output, ) -> Option { None } @@ -107,6 +107,7 @@ pub(crate) struct AnyToolResult { pub(crate) call_id: String, pub(crate) payload: ToolPayload, pub(crate) result: Box, + pub(crate) post_tool_use_payload: Option, } impl AnyToolResult { @@ -143,9 +144,11 @@ pub(crate) struct PreToolUsePayload { pub(crate) struct PostToolUsePayload { /// Hook-facing tool name model. /// - /// Keep this aligned with the corresponding pre-use payload so external - /// hook consumers can pair events by `tool_use_id`. + /// The canonical name is serialized to hook stdin, while aliases are used + /// only for matcher compatibility. pub(crate) tool_name: HookToolName, + /// The originating tool-use id exposed at `tool_use_id`. + pub(crate) tool_use_id: String, /// Command-shaped input exposed at `tool_input.command`. pub(crate) command: String, /// Tool result exposed at `tool_response`. @@ -159,15 +162,7 @@ trait AnyToolHandler: Send + Sync { fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option; - fn post_tool_use_payload( - &self, - call_id: &str, - payload: &ToolPayload, - result: &dyn ToolOutput, - ) -> Option; - fn create_diff_consumer(&self) -> Option>; - fn handle_any<'a>( &'a self, invocation: ToolInvocation, @@ -190,19 +185,9 @@ where ToolHandler::pre_tool_use_payload(self, invocation) } - fn post_tool_use_payload( - &self, - call_id: &str, - payload: &ToolPayload, - result: &dyn ToolOutput, - ) -> Option { - ToolHandler::post_tool_use_payload(self, call_id, payload, result) - } - fn create_diff_consumer(&self) -> Option> { ToolHandler::create_diff_consumer(self) } - fn handle_any<'a>( &'a self, invocation: ToolInvocation, @@ -211,10 +196,13 @@ where let call_id = invocation.call_id.clone(); let payload = invocation.payload.clone(); let output = self.handle(invocation).await?; + let post_tool_use_payload = + ToolHandler::post_tool_use_payload(self, &call_id, &payload, &output); Ok(AnyToolResult { call_id, payload, result: Box::new(output), + post_tool_use_payload, }) }) } @@ -400,13 +388,9 @@ impl ToolRegistry { emit_metric_for_tool_read(&invocation, success).await; let post_tool_use_payload = if success { let guard = response_cell.lock().await; - guard.as_ref().and_then(|result| { - handler.post_tool_use_payload( - &result.call_id, - &result.payload, - result.result.as_ref(), - ) - }) + guard + .as_ref() + .and_then(|result| result.post_tool_use_payload.clone()) } else { None }; @@ -415,7 +399,7 @@ impl ToolRegistry { run_post_tool_use_hooks( &invocation.session, &invocation.turn, - invocation.call_id.clone(), + post_tool_use_payload.tool_use_id, post_tool_use_payload.tool_name.name().to_string(), post_tool_use_payload.tool_name.matcher_aliases().to_vec(), post_tool_use_payload.command, diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index 01c8914bc24c..a5a6e69f896d 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -148,7 +148,7 @@ struct ProcessEntry { process: Arc, call_id: String, process_id: i32, - command: Vec, + hook_command: String, tty: bool, network_approval_id: Option, session: Weak, diff --git a/codex-rs/core/src/unified_exec/mod_tests.rs b/codex-rs/core/src/unified_exec/mod_tests.rs index 1865188d3e7d..f96498b5a855 100644 --- a/codex-rs/core/src/unified_exec/mod_tests.rs +++ b/codex-rs/core/src/unified_exec/mod_tests.rs @@ -113,7 +113,7 @@ async fn exec_command_with_tty( process: Arc::clone(&process), call_id: context.call_id.clone(), process_id, - command: command.clone(), + hook_command: cmd.to_string(), tty, network_approval_id: None, session: Arc::downgrade(session), @@ -165,7 +165,7 @@ async fn exec_command_with_tty( process_id: response_process_id, exit_code, original_token_count: Some(approx_token_count(&text)), - session_command: Some(command), + hook_command: Some(cmd.to_string()), }) } diff --git a/codex-rs/core/src/unified_exec/process_manager.rs b/codex-rs/core/src/unified_exec/process_manager.rs index ef1f3bd81cfa..ec0b00cf590b 100644 --- a/codex-rs/core/src/unified_exec/process_manager.rs +++ b/codex-rs/core/src/unified_exec/process_manager.rs @@ -169,7 +169,7 @@ struct PreparedProcessHandles { output_closed_notify: Arc, cancellation_token: CancellationToken, pause_state: Option>, - command: Vec, + hook_command: String, process_id: i32, tty: bool, } @@ -279,6 +279,7 @@ impl UnifiedExecProcessManager { Arc::clone(&process), context, &request.command, + request.hook_command.clone(), cwd.clone(), start, request.process_id, @@ -398,7 +399,7 @@ impl UnifiedExecProcessManager { process_id: response_process_id, exit_code, original_token_count: Some(original_token_count), - session_command: Some(request.command.clone()), + hook_command: Some(request.hook_command.clone()), }; Ok(response) @@ -418,7 +419,7 @@ impl UnifiedExecProcessManager { output_closed_notify, cancellation_token, pause_state, - command: session_command, + hook_command, process_id, tty, .. @@ -517,7 +518,7 @@ impl UnifiedExecProcessManager { process_id, exit_code, original_token_count: Some(original_token_count), - session_command: Some(session_command.clone()), + hook_command: Some(hook_command), }; Ok(response) @@ -585,7 +586,7 @@ impl UnifiedExecProcessManager { output_closed_notify, cancellation_token, pause_state, - command: entry.command.clone(), + hook_command: entry.hook_command.clone(), process_id: entry.process_id, tty: entry.tty, }) @@ -597,6 +598,7 @@ impl UnifiedExecProcessManager { process: Arc, context: &UnifiedExecContext, command: &[String], + hook_command: String, cwd: AbsolutePathBuf, started_at: Instant, process_id: i32, @@ -608,7 +610,7 @@ impl UnifiedExecProcessManager { process: Arc::clone(&process), call_id: context.call_id.clone(), process_id, - command: command.to_vec(), + hook_command, tty, network_approval_id, session: Arc::downgrade(&context.session), diff --git a/codex-rs/core/tests/suite/hooks.rs b/codex-rs/core/tests/suite/hooks.rs index 4c96de85fa81..022c528753df 100644 --- a/codex-rs/core/tests/suite/hooks.rs +++ b/codex-rs/core/tests/suite/hooks.rs @@ -33,6 +33,7 @@ use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::skip_if_no_network; +use core_test_support::skip_if_windows; use core_test_support::streaming_sse::StreamingSseChunk; use core_test_support::streaming_sse::start_streaming_sse_server; use core_test_support::test_codex::test_codex; @@ -407,6 +408,64 @@ elif mode == "exit_2": Ok(()) } +fn write_logging_pre_and_blocking_post_tool_use_hooks(home: &Path, feedback: &str) -> Result<()> { + let pre_script_path = home.join("pre_tool_use_hook.py"); + let pre_log_path = home.join("pre_tool_use_hook_log.jsonl"); + let post_script_path = home.join("post_tool_use_hook.py"); + let post_log_path = home.join("post_tool_use_hook_log.jsonl"); + let feedback_json = + serde_json::to_string(feedback).context("serialize post tool use feedback")?; + let pre_script = format!( + r#"import json +from pathlib import Path +import sys + +payload = json.load(sys.stdin) +with Path(r"{pre_log_path}").open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload) + "\n") +"#, + pre_log_path = pre_log_path.display(), + ); + let post_script = format!( + r#"import json +from pathlib import Path +import sys + +payload = json.load(sys.stdin) +with Path(r"{post_log_path}").open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload) + "\n") +sys.stderr.write({feedback_json} + "\n") +raise SystemExit(2) +"#, + post_log_path = post_log_path.display(), + ); + let hooks = serde_json::json!({ + "hooks": { + "PreToolUse": [{ + "matcher": "Bash", + "hooks": [{ + "type": "command", + "command": format!("python3 {}", pre_script_path.display()), + "statusMessage": "running pre tool use hook", + }] + }], + "PostToolUse": [{ + "matcher": "Bash", + "hooks": [{ + "type": "command", + "command": format!("python3 {}", post_script_path.display()), + "statusMessage": "running post tool use hook", + }] + }] + } + }); + + fs::write(&pre_script_path, pre_script).context("write pre tool use hook script")?; + fs::write(&post_script_path, post_script).context("write post tool use hook script")?; + fs::write(home.join("hooks.json"), hooks.to_string()).context("write hooks.json")?; + Ok(()) +} + fn write_session_start_hook_recording_transcript(home: &Path) -> Result<()> { let script_path = home.join("session_start_hook.py"); let log_path = home.join("session_start_hook_log.jsonl"); @@ -2514,6 +2573,112 @@ async fn post_tool_use_exit_two_replaces_one_shot_exec_command_output_with_feedb Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn post_tool_use_blocks_when_exec_session_completes_via_write_stdin() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_windows!(Ok(())); + + let server = start_mock_server().await; + let start_call_id = "posttooluse-exec-session-start"; + let poll_call_id = "posttooluse-exec-session-poll"; + let command = "sleep 1; printf session-post-hook-output".to_string(); + let start_args = serde_json::json!({ + "cmd": command, + "shell": "/bin/sh", + "login": false, + "tty": false, + "yield_time_ms": 250, + }); + let poll_args = serde_json::json!({ + "session_id": 1000, + "chars": "", + "yield_time_ms": 5_000, + }); + let feedback = "blocked by session post hook"; + let responses = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-1"), + core_test_support::responses::ev_function_call( + start_call_id, + "exec_command", + &serde_json::to_string(&start_args)?, + ), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + core_test_support::responses::ev_function_call( + poll_call_id, + "write_stdin", + &serde_json::to_string(&poll_args)?, + ), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_response_created("resp-3"), + ev_assistant_message("msg-1", "session post hook observed"), + ev_completed("resp-3"), + ]), + ], + ) + .await; + + let mut builder = test_codex() + .with_pre_build_hook(|home| { + if let Err(error) = write_logging_pre_and_blocking_post_tool_use_hooks(home, feedback) { + panic!("failed to write tool use hook test fixture: {error}"); + } + }) + .with_config(|config| { + config.use_experimental_unified_exec_tool = true; + config + .features + .enable(Feature::CodexHooks) + .expect("test config should allow feature update"); + config + .features + .enable(Feature::UnifiedExec) + .expect("test config should allow feature update"); + }); + let test = builder.build(&server).await?; + + test.submit_turn("run the exec command session with post hook") + .await?; + + let requests = responses.requests(); + assert_eq!(requests.len(), 3); + let output_item = requests[2].function_call_output(poll_call_id); + let output = output_item + .get("output") + .and_then(Value::as_str) + .expect("write_stdin output string"); + assert_eq!(output, feedback); + + let pre_hook_inputs = read_pre_tool_use_hook_inputs(test.codex_home_path())?; + assert_eq!(pre_hook_inputs.len(), 1); + assert_eq!(pre_hook_inputs[0]["tool_name"], "Bash"); + assert_eq!(pre_hook_inputs[0]["tool_use_id"], start_call_id); + assert_eq!(pre_hook_inputs[0]["tool_input"]["command"], command); + + let post_hook_inputs = read_post_tool_use_hook_inputs(test.codex_home_path())?; + assert_eq!(post_hook_inputs.len(), 1); + assert_eq!(post_hook_inputs[0]["hook_event_name"], "PostToolUse"); + assert_eq!(post_hook_inputs[0]["tool_name"], "Bash"); + assert_eq!(post_hook_inputs[0]["tool_use_id"], start_call_id); + assert_eq!(post_hook_inputs[0]["tool_input"]["command"], command); + assert!( + post_hook_inputs[0]["tool_response"] + .as_str() + .is_some_and(|tool_response| tool_response.contains("session-post-hook-output")), + "PostToolUse should see the final session output, got {:?}", + post_hook_inputs[0]["tool_response"] + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn post_tool_use_records_additional_context_for_apply_patch() -> Result<()> { skip_if_no_network!(Ok(())); diff --git a/codex-rs/hooks/src/events/common.rs b/codex-rs/hooks/src/events/common.rs index e74c2912dd48..80dd325378f7 100644 --- a/codex-rs/hooks/src/events/common.rs +++ b/codex-rs/hooks/src/events/common.rs @@ -109,7 +109,7 @@ pub(crate) fn matcher_pattern_for_event( } pub(crate) fn validate_matcher_pattern(matcher: &str) -> Result<(), regex::Error> { - if is_match_all_matcher(matcher) { + if is_match_all_matcher(matcher) || is_exact_matcher(matcher) { return Ok(()); } regex::Regex::new(matcher).map(|_| ()) @@ -119,6 +119,9 @@ pub(crate) fn matches_matcher(matcher: Option<&str>, input: Option<&str>) -> boo match matcher { None => true, Some(matcher) if is_match_all_matcher(matcher) => true, + Some(matcher) if is_exact_matcher(matcher) => input + .map(|input| matcher.split('|').any(|candidate| candidate == input)) + .unwrap_or(false), Some(matcher) => input .and_then(|input| { regex::Regex::new(matcher) @@ -144,6 +147,12 @@ fn is_match_all_matcher(matcher: &str) -> bool { matcher.is_empty() || matcher == "*" } +fn is_exact_matcher(matcher: &str) -> bool { + matcher + .chars() + .all(|ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '|') +} + #[cfg(test)] mod tests { use codex_protocol::protocol::HookEventName; @@ -181,6 +190,27 @@ mod tests { assert_eq!(validate_matcher_pattern("Edit|Write"), Ok(())); } + #[test] + fn matcher_exact_string_does_not_substring_match() { + assert!(matches_matcher(Some("Bash"), Some("Bash"))); + assert!(!matches_matcher(Some("Bash"), Some("BashOutput"))); + assert_eq!(validate_matcher_pattern("Bash"), Ok(())); + } + + #[test] + fn matcher_uses_regex_when_it_contains_other_characters() { + assert!(matches_matcher(Some("^Bash"), Some("BashOutput"))); + assert!(matches_matcher( + Some("mcp__memory__.*"), + Some("mcp__memory__create_entities") + )); + assert!(!matches_matcher( + Some("mcp__memory"), + Some("mcp__memory__create_entities") + )); + assert_eq!(validate_matcher_pattern("mcp__memory__.*"), Ok(())); + } + #[test] fn matcher_supports_anchored_regexes() { assert!(matches_matcher(Some("^Bash$"), Some("Bash")));