From cf0785e99bfbf5db3c63e28cc4630ee42681a9f0 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Tue, 10 Mar 2026 19:48:06 -0700 Subject: [PATCH 1/7] codex: persist code mode runner sessions --- codex-rs/core/src/state/service.rs | 11 ++ codex-rs/core/src/tools/code_mode.rs | 148 ++++++++++++++----- codex-rs/core/src/tools/code_mode_runner.cjs | 69 +++++---- 3 files changed, 159 insertions(+), 69 deletions(-) diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 5c0a741a126..0c436a921be 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -15,6 +15,7 @@ use crate::models_manager::manager::ModelsManager; use crate::plugins::PluginsManager; use crate::skills::SkillsManager; use crate::state_db::StateDbHandle; +use crate::tools::code_mode::CodeModeProcess; use crate::tools::network_approval::NetworkApprovalService; use crate::tools::runtimes::ExecveSessionApproval; use crate::tools::sandboxing::ApprovalStore; @@ -31,12 +32,14 @@ use tokio_util::sync::CancellationToken; pub(crate) struct CodeModeStoreService { stored_values: Mutex>, + process: Mutex>, } impl Default for CodeModeStoreService { fn default() -> Self { Self { stored_values: Mutex::new(HashMap::new()), + process: Mutex::new(None), } } } @@ -49,6 +52,14 @@ impl CodeModeStoreService { pub(crate) async fn replace_stored_values(&self, values: HashMap) { *self.stored_values.lock().await = values; } + + pub(crate) async fn store_process(&self, process: CodeModeProcess) { + *self.process.lock().await = Some(process); + } + + pub(crate) async fn take_process(&self) -> Option { + self.process.lock().await.take() + } } pub(crate) struct SessionServices { diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index ba8dd29e04f..955691c7844 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -30,6 +30,7 @@ use tokio::io::AsyncBufReadExt; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; +use tokio::task::JoinHandle; const CODE_MODE_RUNNER_SOURCE: &str = include_str!("code_mode_runner.cjs"); const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("code_mode_bridge.js"); @@ -42,6 +43,37 @@ struct ExecContext { tracker: SharedTurnDiffTracker, } +pub(crate) struct CodeModeProcess { + child: tokio::process::Child, + stdin: tokio::process::ChildStdin, + stdout_lines: tokio::io::Lines>, + stderr_task: Option>, +} + +impl CodeModeProcess { + fn has_exited(&mut self) -> Result { + self.child + .try_wait() + .map(|status| status.is_some()) + .map_err(|err| format!("failed to inspect {PUBLIC_TOOL_NAME} runner: {err}")) + } + + async fn wait_for_exit(&mut self) -> Result { + self.child + .wait() + .await + .map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}")) + } + + async fn stderr(&mut self) -> Result { + self.stderr_task + .take() + .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} stderr collector missing"))? + .await + .map_err(|err| format!("failed to collect {PUBLIC_TOOL_NAME} stderr: {err}")) + } +} + #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] enum CodeModeToolKind { @@ -63,12 +95,14 @@ struct EnabledTool { #[derive(Serialize)] #[serde(tag = "type", rename_all = "snake_case")] enum HostToNodeMessage { - Init { + Start { + session_id: String, enabled_tools: Vec, stored_values: HashMap, source: String, }, Response { + session_id: String, id: String, code_mode_result: JsonValue, }, @@ -78,12 +112,14 @@ enum HostToNodeMessage { #[serde(tag = "type", rename_all = "snake_case")] enum NodeToHostMessage { ToolCall { + session_id: String, id: String, name: String, #[serde(default)] input: Option, }, Result { + session_id: String, content_items: Vec, stored_values: HashMap, #[serde(default)] @@ -138,20 +174,33 @@ pub(crate) async fn execute( let enabled_tools = build_enabled_tools(&exec).await; let stored_values = exec.session.services.code_mode_store.stored_values().await; let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; - execute_node(exec, source, enabled_tools, stored_values) - .await - .map_err(FunctionCallError::RespondToModel) + let mut process = match exec.session.services.code_mode_store.take_process().await { + Some(mut process) => { + if matches!(process.has_exited(), Ok(false)) { + process + } else { + spawn_code_mode_process(&exec) + .await + .map_err(FunctionCallError::RespondToModel)? + } + } + None => spawn_code_mode_process(&exec) + .await + .map_err(FunctionCallError::RespondToModel)?, + }; + let result = execute_node(&exec, &mut process, source, enabled_tools, stored_values).await; + if result.is_ok() && matches!(process.has_exited(), Ok(false)) { + exec.session + .services + .code_mode_store + .store_process(process) + .await; + } + result.map_err(FunctionCallError::RespondToModel) } -async fn execute_node( - exec: ExecContext, - source: String, - enabled_tools: Vec, - stored_values: HashMap, -) -> Result { +async fn spawn_code_mode_process(exec: &ExecContext) -> Result { let node_path = resolve_compatible_node(exec.turn.config.js_repl_node_path.as_deref()).await?; - let started_at = std::time::Instant::now(); - let env = create_env(&exec.turn.shell_environment_policy, None); let mut cmd = tokio::process::Command::new(&node_path); cmd.arg("--experimental-vm-modules"); @@ -176,7 +225,7 @@ async fn execute_node( .stderr .take() .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stderr"))?; - let mut stdin = child + let stdin = child .stdin .take() .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdin"))?; @@ -188,19 +237,38 @@ async fn execute_node( String::from_utf8_lossy(&buf).trim().to_string() }); + Ok(CodeModeProcess { + child, + stdin, + stdout_lines: BufReader::new(stdout).lines(), + stderr_task: Some(stderr_task), + }) +} + +async fn execute_node( + exec: &ExecContext, + process: &mut CodeModeProcess, + source: String, + enabled_tools: Vec, + stored_values: HashMap, +) -> Result { + let started_at = std::time::Instant::now(); + let session_id = uuid::Uuid::new_v4().to_string(); + write_message( - &mut stdin, - &HostToNodeMessage::Init { - enabled_tools: enabled_tools.clone(), + &mut process.stdin, + &HostToNodeMessage::Start { + session_id: session_id.clone(), + enabled_tools, stored_values, source, }, ) .await?; - let mut stdout_lines = BufReader::new(stdout).lines(); let mut pending_result = None; - while let Some(line) = stdout_lines + while let Some(line) = process + .stdout_lines .next_line() .await .map_err(|err| format!("failed to read {PUBLIC_TOOL_NAME} runner stdout: {err}"))? @@ -212,19 +280,36 @@ async fn execute_node( format!("invalid {PUBLIC_TOOL_NAME} runner message: {err}; line={line}") })?; match message { - NodeToHostMessage::ToolCall { id, name, input } => { + NodeToHostMessage::ToolCall { + session_id: message_session_id, + id, + name, + input, + } => { + if message_session_id != session_id { + return Err(format!( + "unexpected {PUBLIC_TOOL_NAME} runner tool call session id: {message_session_id}" + )); + } let response = HostToNodeMessage::Response { + session_id: message_session_id, id, code_mode_result: call_nested_tool(exec.clone(), name, input).await, }; - write_message(&mut stdin, &response).await?; + write_message(&mut process.stdin, &response).await?; } NodeToHostMessage::Result { + session_id: message_session_id, content_items, stored_values, error_text, max_output_tokens_per_exec_call, } => { + if message_session_id != session_id { + return Err(format!( + "unexpected {PUBLIC_TOOL_NAME} runner result session id: {message_session_id}" + )); + } exec.session .services .code_mode_store @@ -240,20 +325,11 @@ async fn execute_node( } } - drop(stdin); - - let status = child - .wait() - .await - .map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}"))?; - let stderr = stderr_task - .await - .map_err(|err| format!("failed to collect {PUBLIC_TOOL_NAME} stderr: {err}"))?; let wall_time = started_at.elapsed(); - let success = status.success(); - let Some((mut content_items, error_text, max_output_tokens_per_exec_call)) = pending_result else { + let status = process.wait_for_exit().await?; + let stderr = process.stderr().await?; let message = if stderr.is_empty() { format!("{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})") } else { @@ -262,14 +338,8 @@ async fn execute_node( return Err(message); }; - if !success { - let error_text = error_text.unwrap_or_else(|| { - if stderr.is_empty() { - format!("Process exited with status {status}") - } else { - stderr - } - }); + let success = error_text.is_none(); + if let Some(error_text) = error_text { content_items.push(FunctionCallOutputContentItem::InputText { text: format!("Script error:\n{error_text}"), }); diff --git a/codex-rs/core/src/tools/code_mode_runner.cjs b/codex-rs/core/src/tools/code_mode_runner.cjs index f36fa6f92ef..921564436e3 100644 --- a/codex-rs/core/src/tools/code_mode_runner.cjs +++ b/codex-rs/core/src/tools/code_mode_runner.cjs @@ -21,11 +21,10 @@ function createProtocol() { let nextId = 0; const pending = new Map(); - let initResolve; - let initReject; - const init = new Promise((resolve, reject) => { - initResolve = resolve; - initReject = reject; + const sessions = new Map(); + let closedResolve; + const closed = new Promise((resolve) => { + closedResolve = resolve; }); rl.on('line', (line) => { @@ -37,35 +36,38 @@ function createProtocol() { try { message = JSON.parse(line); } catch (error) { - initReject(error); + process.stderr.write(`${formatErrorText(error)}\n`); return; } - if (message.type === 'init') { - initResolve(message); + if (message.type === 'start') { + const session = { id: String(message.session_id) }; + sessions.set(session.id, session); + void processSession(protocol, sessions, session, message); return; } if (message.type === 'response') { - const entry = pending.get(message.id); + const entry = pending.get(`${message.session_id}:${message.id}`); if (!entry) { return; } - pending.delete(message.id); + pending.delete(`${message.session_id}:${message.id}`); entry.resolve(message.code_mode_result ?? ''); return; } - initReject(new Error(`Unknown protocol message type: ${message.type}`)); + process.stderr.write(`Unknown protocol message type: ${message.type}\n`); }); rl.on('close', () => { const error = new Error('stdin closed'); - initReject(error); for (const entry of pending.values()) { entry.reject(error); } pending.clear(); + sessions.clear(); + closedResolve(); }); function send(message) { @@ -80,18 +82,20 @@ function createProtocol() { }); } - function request(type, payload) { + function request(sessionId, type, payload) { const id = `msg-${++nextId}`; + const pendingKey = `${sessionId}:${id}`; return new Promise((resolve, reject) => { - pending.set(id, { resolve, reject }); - void send({ type, id, ...payload }).catch((error) => { - pending.delete(id); + pending.set(pendingKey, { resolve, reject }); + void send({ type, session_id: sessionId, id, ...payload }).catch((error) => { + pending.delete(pendingKey); reject(error); }); }); } - return { init, request, send }; + const protocol = { closed, request, send }; + return protocol; } function readContentItems(context) { @@ -112,9 +116,9 @@ function cloneJsonValue(value) { return JSON.parse(JSON.stringify(value)); } -function createToolCaller(protocol) { +function createToolCaller(protocol, sessionId) { return (name, input) => - protocol.request('tool_call', { + protocol.request(sessionId, 'tool_call', { name: String(name), input, }); @@ -348,14 +352,14 @@ function createModuleResolver(context, callTool, enabledTools, state) { }; } -async function runModule(context, request, state, callTool) { +async function runModule(context, start, state, callTool) { const resolveModule = createModuleResolver( context, callTool, - request.enabled_tools ?? [], + start.enabled_tools ?? [], state ); - const mainModule = new SourceTextModule(request.source, { + const mainModule = new SourceTextModule(start.source, { context, identifier: 'exec_main.mjs', importModuleDynamically: async (specifier) => resolveModule(specifier), @@ -365,40 +369,45 @@ async function runModule(context, request, state, callTool) { await mainModule.evaluate(); } -async function main() { - const protocol = createProtocol(); - const request = await protocol.init; +async function processSession(protocol, sessions, session, start) { const state = { maxOutputTokensPerExecCall: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, - storedValues: cloneJsonValue(request.stored_values ?? {}), + storedValues: cloneJsonValue(start.stored_values ?? {}), }; - const callTool = createToolCaller(protocol); + const callTool = createToolCaller(protocol, session.id); const context = vm.createContext({ __codexContentItems: [], __codex_tool_call: callTool, }); try { - await runModule(context, request, state, callTool); + await runModule(context, start, state, callTool); await protocol.send({ type: 'result', + session_id: session.id, content_items: readContentItems(context), stored_values: state.storedValues, max_output_tokens_per_exec_call: state.maxOutputTokensPerExecCall, }); - process.exit(0); } catch (error) { await protocol.send({ type: 'result', + session_id: session.id, content_items: readContentItems(context), stored_values: state.storedValues, error_text: formatErrorText(error), max_output_tokens_per_exec_call: state.maxOutputTokensPerExecCall, }); - process.exit(1); + } finally { + sessions.delete(session.id); } } +async function main() { + const protocol = createProtocol(); + await protocol.closed; +} + void main().catch(async (error) => { try { process.stderr.write(`${formatErrorText(error)}\n`); From e1cc893f9546a0bbe622ea872a528d5c29f851c4 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 17:00:51 -0700 Subject: [PATCH 2/7] Add long-running code-mode session waits --- codex-rs/core/src/state/service.rs | 34 +- codex-rs/core/src/tools/code_mode.rs | 551 +++++++++-- codex-rs/core/src/tools/code_mode_runner.cjs | 887 ++++++++++++------ codex-rs/core/src/tools/handlers/code_mode.rs | 79 +- codex-rs/core/src/tools/handlers/mod.rs | 1 + codex-rs/core/src/tools/spec.rs | 65 +- codex-rs/core/tests/suite/code_mode.rs | 794 ++++++++++++++++ 7 files changed, 2015 insertions(+), 396 deletions(-) diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 0c436a921be..cc4791ae23b 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -16,6 +16,7 @@ use crate::plugins::PluginsManager; use crate::skills::SkillsManager; use crate::state_db::StateDbHandle; use crate::tools::code_mode::CodeModeProcess; +use crate::tools::code_mode::CodeModeYieldedSession; use crate::tools::network_approval::NetworkApprovalService; use crate::tools::runtimes::ExecveSessionApproval; use crate::tools::sandboxing::ApprovalStore; @@ -32,7 +33,9 @@ use tokio_util::sync::CancellationToken; pub(crate) struct CodeModeStoreService { stored_values: Mutex>, - process: Mutex>, + process: Mutex>>>, + yielded_sessions: Mutex>, + next_session_id: Mutex, } impl Default for CodeModeStoreService { @@ -40,6 +43,8 @@ impl Default for CodeModeStoreService { Self { stored_values: Mutex::new(HashMap::new()), process: Mutex::new(None), + yielded_sessions: Mutex::new(HashMap::new()), + next_session_id: Mutex::new(1), } } } @@ -53,12 +58,33 @@ impl CodeModeStoreService { *self.stored_values.lock().await = values; } - pub(crate) async fn store_process(&self, process: CodeModeProcess) { + pub(crate) async fn store_process(&self, process: Arc>) { *self.process.lock().await = Some(process); } - pub(crate) async fn take_process(&self) -> Option { - self.process.lock().await.take() + pub(crate) async fn process(&self) -> Option>> { + self.process.lock().await.clone() + } + + pub(crate) async fn allocate_session_id(&self) -> i32 { + let mut next_session_id = self.next_session_id.lock().await; + let session_id = *next_session_id; + *next_session_id = next_session_id.saturating_add(1); + session_id + } + + pub(crate) async fn store_yielded_session(&self, yielded_session: CodeModeYieldedSession) { + self.yielded_sessions + .lock() + .await + .insert(yielded_session.session_id, yielded_session); + } + + pub(crate) async fn take_yielded_session( + &self, + session_id: i32, + ) -> Option { + self.yielded_sessions.lock().await.remove(&session_id) } } diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index 955691c7844..945310cb844 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::VecDeque; use std::sync::Arc; use std::time::Duration; @@ -35,6 +36,8 @@ use tokio::task::JoinHandle; const CODE_MODE_RUNNER_SOURCE: &str = include_str!("code_mode_runner.cjs"); const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("code_mode_bridge.js"); pub(crate) const PUBLIC_TOOL_NAME: &str = "exec"; +pub(crate) const WAIT_TOOL_NAME: &str = "exec_wait"; +pub(crate) const DEFAULT_WAIT_YIELD_TIME_MS: u64 = 10_000; #[derive(Clone)] struct ExecContext { @@ -48,6 +51,7 @@ pub(crate) struct CodeModeProcess { stdin: tokio::process::ChildStdin, stdout_lines: tokio::io::Lines>, stderr_task: Option>, + pending_messages: HashMap>, } impl CodeModeProcess { @@ -74,6 +78,11 @@ impl CodeModeProcess { } } +#[derive(Clone, Debug)] +pub(crate) struct CodeModeYieldedSession { + pub(crate) session_id: i32, +} + #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] enum CodeModeToolKind { @@ -96,13 +105,20 @@ struct EnabledTool { #[serde(tag = "type", rename_all = "snake_case")] enum HostToNodeMessage { Start { - session_id: String, + session_id: i32, enabled_tools: Vec, stored_values: HashMap, source: String, }, + Poll { + session_id: i32, + yield_time_ms: u64, + }, + Terminate { + session_id: i32, + }, Response { - session_id: String, + session_id: i32, id: String, code_mode_result: JsonValue, }, @@ -112,14 +128,26 @@ enum HostToNodeMessage { #[serde(tag = "type", rename_all = "snake_case")] enum NodeToHostMessage { ToolCall { - session_id: String, + session_id: i32, id: String, name: String, #[serde(default)] input: Option, }, + Yielded { + session_id: i32, + content_items: Vec, + #[serde(default)] + max_output_tokens_per_exec_call: Option, + }, + Terminated { + session_id: i32, + content_items: Vec, + #[serde(default)] + max_output_tokens_per_exec_call: Option, + }, Result { - session_id: String, + session_id: i32, content_items: Vec, stored_values: HashMap, #[serde(default)] @@ -129,6 +157,36 @@ enum NodeToHostMessage { }, } +enum CodeModeSessionAction { + Start { + enabled_tools: Vec, + stored_values: HashMap, + source: String, + }, + Poll { + yield_time_ms: u64, + max_output_tokens: Option, + }, + Terminate { + max_output_tokens: Option, + }, +} + +enum CodeModeSessionProgress { + Finished(FunctionToolOutput), + Yielded { + output: FunctionToolOutput, + yielded_session: CodeModeYieldedSession, + }, +} + +enum CodeModeExecutionStatus { + Completed, + Failed, + Running(i32), + Terminated, +} + pub(crate) fn instructions(config: &Config) -> Option { if !config.features.enabled(Feature::CodeMode) { return None; @@ -149,7 +207,10 @@ pub(crate) fn instructions(config: &Config) -> Option { )); section.push_str("- Import nested tools from `tools.js`, for example `import { exec_command } from \"tools.js\"` or `import { ALL_TOOLS } from \"tools.js\"` to inspect the available `{ module, name, description }` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import { append_notebook_logs_chart } from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values.\n"); section.push_str(&format!( - "- Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, store, load }}` from `@openai/code_mode` (or `\"openai/code_mode\"`). `output_text(value)` surfaces text back to the model and stringifies non-string objects with `JSON.stringify(...)` when possible. `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs. `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, and `load(key)` returns a cloned stored value or `undefined`. `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate the final Rust-side result of the current `{PUBLIC_TOOL_NAME}` execution; the default is `10000`. This guards the overall `{PUBLIC_TOOL_NAME}` output, not individual nested tool invocations. The returned content starts with a separate `Script completed` or `Script failed` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker.\n", + "- Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, set_yield_time, store, load }}` from `@openai/code_mode` (or `\"openai/code_mode\"`). `output_text(value)` surfaces text back to the model and stringifies non-string objects with `JSON.stringify(...)` when possible. `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs. `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, and `load(key)` returns a cloned stored value or `undefined`. `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate direct `{PUBLIC_TOOL_NAME}` returns; `{WAIT_TOOL_NAME}` uses its own `max_tokens` argument instead and defaults to `10000`. `set_yield_time(value)` asks `{PUBLIC_TOOL_NAME}` to return early if the script is still running after that many milliseconds so `{WAIT_TOOL_NAME}` can resume it later. The returned content starts with a separate `Script completed`, `Script failed`, or `Script running with session ID …` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker.\n", + )); + section.push_str(&format!( + "- If `{PUBLIC_TOOL_NAME}` returns `Script running with session ID …`, call `{WAIT_TOOL_NAME}` with that `session_id` to keep waiting for more output, completion, or termination.\n", )); section.push_str( "- Function tools require JSON object arguments. Freeform tools require raw strings.\n", @@ -174,29 +235,106 @@ pub(crate) async fn execute( let enabled_tools = build_enabled_tools(&exec).await; let stored_values = exec.session.services.code_mode_store.stored_values().await; let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; - let mut process = match exec.session.services.code_mode_store.take_process().await { - Some(mut process) => { - if matches!(process.has_exited(), Ok(false)) { - process + let session_id = exec + .session + .services + .code_mode_store + .allocate_session_id() + .await; + let process = ensure_shared_code_mode_process(&exec) + .await + .map_err(FunctionCallError::RespondToModel)?; + let result = { + let mut process = process.lock().await; + drive_code_mode_session( + &exec, + &mut process, + session_id, + CodeModeSessionAction::Start { + enabled_tools, + stored_values, + source, + }, + ) + .await + }; + if let Ok(CodeModeSessionProgress::Yielded { + yielded_session, .. + }) = &result + { + exec.session + .services + .code_mode_store + .store_yielded_session(yielded_session.clone()) + .await; + } + match result { + Ok(CodeModeSessionProgress::Finished(output)) + | Ok(CodeModeSessionProgress::Yielded { output, .. }) => Ok(output), + Err(error) => Err(FunctionCallError::RespondToModel(error)), + } +} + +pub(crate) async fn wait( + session: Arc, + turn: Arc, + tracker: SharedTurnDiffTracker, + session_id: i32, + yield_time_ms: u64, + max_output_tokens: Option, + terminate: bool, +) -> Result { + let exec = ExecContext { + session, + turn, + tracker, + }; + let yielded_session = exec + .session + .services + .code_mode_store + .take_yielded_session(session_id) + .await + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "{WAIT_TOOL_NAME} session_id {session_id} is not waiting on {WAIT_TOOL_NAME}" + )) + })?; + + let process = existing_shared_code_mode_process(&exec).await?; + let result = { + let mut process = process.lock().await; + drive_code_mode_session( + &exec, + &mut process, + yielded_session.session_id, + if terminate { + CodeModeSessionAction::Terminate { max_output_tokens } } else { - spawn_code_mode_process(&exec) - .await - .map_err(FunctionCallError::RespondToModel)? - } - } - None => spawn_code_mode_process(&exec) - .await - .map_err(FunctionCallError::RespondToModel)?, + CodeModeSessionAction::Poll { + yield_time_ms, + max_output_tokens, + } + }, + ) + .await }; - let result = execute_node(&exec, &mut process, source, enabled_tools, stored_values).await; - if result.is_ok() && matches!(process.has_exited(), Ok(false)) { + if let Ok(CodeModeSessionProgress::Yielded { + yielded_session, .. + }) = &result + { exec.session .services .code_mode_store - .store_process(process) + .store_yielded_session(yielded_session.clone()) .await; } - result.map_err(FunctionCallError::RespondToModel) + + match result { + Ok(CodeModeSessionProgress::Finished(output)) + | Ok(CodeModeSessionProgress::Yielded { output, .. }) => Ok(output), + Err(error) => Err(FunctionCallError::RespondToModel(error)), + } } async fn spawn_code_mode_process(exec: &ExecContext) -> Result { @@ -242,31 +380,125 @@ async fn spawn_code_mode_process(exec: &ExecContext) -> Result Result>, String> { + if let Some(process) = exec.session.services.code_mode_store.process().await { + let is_running = { + let mut process_guard = process.lock().await; + matches!(process_guard.has_exited(), Ok(false)) + }; + if is_running { + return Ok(process); + } + } + + let process = Arc::new(tokio::sync::Mutex::new( + spawn_code_mode_process(exec).await?, + )); + exec.session + .services + .code_mode_store + .store_process(process.clone()) + .await; + Ok(process) +} + +async fn existing_shared_code_mode_process( + exec: &ExecContext, +) -> Result>, FunctionCallError> { + let process = exec + .session + .services + .code_mode_store + .process() + .await + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" + )) + })?; + let is_running = { + let mut process_guard = process.lock().await; + matches!(process_guard.has_exited(), Ok(false)) + }; + if is_running { + Ok(process) + } else { + Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" + ))) + } +} + +async fn drive_code_mode_session( exec: &ExecContext, process: &mut CodeModeProcess, - source: String, - enabled_tools: Vec, - stored_values: HashMap, -) -> Result { + session_id: i32, + action: CodeModeSessionAction, +) -> Result { let started_at = std::time::Instant::now(); - let session_id = uuid::Uuid::new_v4().to_string(); - - write_message( - &mut process.stdin, - &HostToNodeMessage::Start { - session_id: session_id.clone(), + let is_terminate = matches!(action, CodeModeSessionAction::Terminate { .. }); + let (message, poll_max_output_tokens) = match action { + CodeModeSessionAction::Start { enabled_tools, stored_values, source, - }, + } => ( + HostToNodeMessage::Start { + session_id, + enabled_tools, + stored_values, + source, + }, + None, + ), + CodeModeSessionAction::Poll { + yield_time_ms, + max_output_tokens, + } => ( + HostToNodeMessage::Poll { + session_id, + yield_time_ms, + }, + Some(max_output_tokens), + ), + CodeModeSessionAction::Terminate { max_output_tokens } => ( + HostToNodeMessage::Terminate { session_id }, + Some(max_output_tokens), + ), + }; + if let Some(progress) = process_pending_messages( + exec, + process, + session_id, + poll_max_output_tokens, + started_at, + is_terminate, ) - .await?; + .await? + { + return Ok(progress); + } + write_message(&mut process.stdin, &message).await?; + + if let Some(progress) = process_pending_messages( + exec, + process, + session_id, + poll_max_output_tokens, + started_at, + is_terminate, + ) + .await? + { + return Ok(progress); + } - let mut pending_result = None; while let Some(line) = process .stdout_lines .next_line() @@ -279,79 +511,203 @@ async fn execute_node( let message: NodeToHostMessage = serde_json::from_str(&line).map_err(|err| { format!("invalid {PUBLIC_TOOL_NAME} runner message: {err}; line={line}") })?; - match message { - NodeToHostMessage::ToolCall { + let message_session_id = message_session_id(&message); + if message_session_id != session_id { + if let NodeToHostMessage::ToolCall { session_id: message_session_id, id, name, input, - } => { - if message_session_id != session_id { - return Err(format!( - "unexpected {PUBLIC_TOOL_NAME} runner tool call session id: {message_session_id}" - )); - } + } = message + { let response = HostToNodeMessage::Response { session_id: message_session_id, id, code_mode_result: call_nested_tool(exec.clone(), name, input).await, }; write_message(&mut process.stdin, &response).await?; + } else { + process + .pending_messages + .entry(message_session_id) + .or_default() + .push_back(message); } - NodeToHostMessage::Result { - session_id: message_session_id, - content_items, - stored_values, - error_text, - max_output_tokens_per_exec_call, - } => { - if message_session_id != session_id { - return Err(format!( - "unexpected {PUBLIC_TOOL_NAME} runner result session id: {message_session_id}" - )); - } - exec.session - .services - .code_mode_store - .replace_stored_values(stored_values) - .await; - pending_result = Some(( - output_content_items_from_json_values(content_items)?, - error_text, - max_output_tokens_per_exec_call, - )); - break; - } + continue; + } + if let Some(progress) = handle_node_message( + exec, + process, + session_id, + message, + poll_max_output_tokens, + started_at, + is_terminate, + ) + .await? + { + return Ok(progress); } } - let wall_time = started_at.elapsed(); - let Some((mut content_items, error_text, max_output_tokens_per_exec_call)) = pending_result - else { - let status = process.wait_for_exit().await?; - let stderr = process.stderr().await?; - let message = if stderr.is_empty() { - format!("{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})") - } else { - stderr - }; - return Err(message); + let status = process.wait_for_exit().await?; + let stderr = process.stderr().await?; + let message = if stderr.is_empty() { + format!("{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})") + } else { + stderr }; + Err(message) +} - let success = error_text.is_none(); - if let Some(error_text) = error_text { - content_items.push(FunctionCallOutputContentItem::InputText { - text: format!("Script error:\n{error_text}"), - }); +async fn process_pending_messages( + exec: &ExecContext, + process: &mut CodeModeProcess, + session_id: i32, + poll_max_output_tokens: Option>, + started_at: std::time::Instant, + is_terminate: bool, +) -> Result, String> { + loop { + let Some(message) = process + .pending_messages + .get_mut(&session_id) + .and_then(VecDeque::pop_front) + else { + return Ok(None); + }; + if let Some(progress) = handle_node_message( + exec, + process, + session_id, + message, + poll_max_output_tokens, + started_at, + is_terminate, + ) + .await? + { + return Ok(Some(progress)); + } } +} - let mut content_items = - truncate_code_mode_result(content_items, max_output_tokens_per_exec_call); - prepend_script_status(&mut content_items, success, wall_time); - Ok(FunctionToolOutput::from_content( - content_items, - Some(success), - )) +async fn handle_node_message( + exec: &ExecContext, + process: &mut CodeModeProcess, + session_id: i32, + message: NodeToHostMessage, + poll_max_output_tokens: Option>, + started_at: std::time::Instant, + is_terminate: bool, +) -> Result, String> { + match message { + NodeToHostMessage::ToolCall { + session_id: message_session_id, + id, + name, + input, + } => { + if is_terminate { + return Ok(None); + } + let response = HostToNodeMessage::Response { + session_id: message_session_id, + id, + code_mode_result: call_nested_tool(exec.clone(), name, input).await, + }; + write_message(&mut process.stdin, &response).await?; + Ok(None) + } + NodeToHostMessage::Yielded { + content_items, + max_output_tokens_per_exec_call, + .. + } => { + if is_terminate { + return Ok(None); + } + let mut delta_items = output_content_items_from_json_values(content_items)?; + delta_items = truncate_code_mode_result( + delta_items, + poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), + ); + prepend_script_status( + &mut delta_items, + CodeModeExecutionStatus::Running(session_id), + started_at.elapsed(), + ); + Ok(Some(CodeModeSessionProgress::Yielded { + output: FunctionToolOutput::from_content(delta_items, Some(true)), + yielded_session: CodeModeYieldedSession { session_id }, + })) + } + NodeToHostMessage::Terminated { + content_items, + max_output_tokens_per_exec_call, + .. + } => { + let mut delta_items = output_content_items_from_json_values(content_items)?; + delta_items = truncate_code_mode_result( + delta_items, + poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), + ); + prepend_script_status( + &mut delta_items, + CodeModeExecutionStatus::Terminated, + started_at.elapsed(), + ); + Ok(Some(CodeModeSessionProgress::Finished( + FunctionToolOutput::from_content(delta_items, Some(true)), + ))) + } + NodeToHostMessage::Result { + content_items, + stored_values, + error_text, + max_output_tokens_per_exec_call, + .. + } => { + exec.session + .services + .code_mode_store + .replace_stored_values(stored_values) + .await; + let mut delta_items = output_content_items_from_json_values(content_items)?; + let success = error_text.is_none(); + if let Some(error_text) = error_text { + delta_items.push(FunctionCallOutputContentItem::InputText { + text: format!("Script error:\n{error_text}"), + }); + } + + let mut delta_items = truncate_code_mode_result( + delta_items, + poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), + ); + prepend_script_status( + &mut delta_items, + if success { + CodeModeExecutionStatus::Completed + } else { + CodeModeExecutionStatus::Failed + }, + started_at.elapsed(), + ); + Ok(Some(CodeModeSessionProgress::Finished( + FunctionToolOutput::from_content(delta_items, Some(success)), + ))) + } + } +} + +fn message_session_id(message: &NodeToHostMessage) -> i32 { + match message { + NodeToHostMessage::ToolCall { session_id, .. } + | NodeToHostMessage::Yielded { session_id, .. } + | NodeToHostMessage::Terminated { session_id, .. } + | NodeToHostMessage::Result { session_id, .. } => *session_id, + } } async fn write_message( @@ -376,16 +732,19 @@ async fn write_message( fn prepend_script_status( content_items: &mut Vec, - success: bool, + status: CodeModeExecutionStatus, wall_time: Duration, ) { let wall_time_seconds = ((wall_time.as_secs_f32()) * 10.0).round() / 10.0; let header = format!( "{}\nWall time {wall_time_seconds:.1} seconds\nOutput:\n", - if success { - "Script completed" - } else { - "Script failed" + match status { + CodeModeExecutionStatus::Completed => "Script completed".to_string(), + CodeModeExecutionStatus::Failed => "Script failed".to_string(), + CodeModeExecutionStatus::Running(session_id) => { + format!("Script running with session ID {session_id}") + } + CodeModeExecutionStatus::Terminated => "Script terminated".to_string(), } ); content_items.insert(0, FunctionCallOutputContentItem::InputText { text: header }); @@ -435,7 +794,7 @@ async fn build_enabled_tools(exec: &ExecContext) -> Vec { fn enabled_tool_from_spec(spec: ToolSpec) -> Option { let tool_name = spec.name().to_string(); - if tool_name == PUBLIC_TOOL_NAME { + if tool_name == PUBLIC_TOOL_NAME || tool_name == WAIT_TOOL_NAME { return None; } diff --git a/codex-rs/core/src/tools/code_mode_runner.cjs b/codex-rs/core/src/tools/code_mode_runner.cjs index 921564436e3..5bad1c61992 100644 --- a/codex-rs/core/src/tools/code_mode_runner.cjs +++ b/codex-rs/core/src/tools/code_mode_runner.cjs @@ -1,9 +1,8 @@ 'use strict'; const readline = require('node:readline'); -const vm = require('node:vm'); +const { Worker } = require('node:worker_threads'); -const { SourceTextModule, SyntheticModule } = vm; const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL = 10000; function normalizeMaxOutputTokensPerExecCall(value) { @@ -13,6 +12,425 @@ function normalizeMaxOutputTokensPerExecCall(value) { return value; } +function normalizeYieldTime(value) { + if (!Number.isSafeInteger(value) || value < 0) { + throw new TypeError('yield_time must be a non-negative safe integer'); + } + return value; +} + +function formatErrorText(error) { + return String(error && error.stack ? error.stack : error); +} + +function cloneJsonValue(value) { + return JSON.parse(JSON.stringify(value)); +} + +function clearTimer(timer) { + if (timer !== null) { + clearTimeout(timer); + } + return null; +} + +function takeContentItems(session) { + const clonedContentItems = cloneJsonValue(session.content_items); + session.content_items.splice(0, session.content_items.length); + return Array.isArray(clonedContentItems) ? clonedContentItems : []; +} + +function codeModeWorkerMain() { + 'use strict'; + + const { parentPort, workerData } = require('node:worker_threads'); + const vm = require('node:vm'); + const { SourceTextModule, SyntheticModule } = vm; + + const DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL = 10000; + + function normalizeMaxOutputTokensPerExecCall(value) { + if (!Number.isSafeInteger(value) || value < 0) { + throw new TypeError('max_output_tokens_per_exec_call must be a non-negative safe integer'); + } + return value; + } + + function normalizeYieldTime(value) { + if (!Number.isSafeInteger(value) || value < 0) { + throw new TypeError('yield_time must be a non-negative safe integer'); + } + return value; + } + + function formatErrorText(error) { + return String(error && error.stack ? error.stack : error); + } + + function cloneJsonValue(value) { + return JSON.parse(JSON.stringify(value)); + } + + function createToolCaller() { + let nextId = 0; + const pending = new Map(); + + parentPort.on('message', (message) => { + if (message.type === 'tool_response') { + const entry = pending.get(message.id); + if (!entry) { + return; + } + pending.delete(message.id); + entry.resolve(message.result ?? ''); + return; + } + + if (message.type === 'tool_response_error') { + const entry = pending.get(message.id); + if (!entry) { + return; + } + pending.delete(message.id); + entry.reject(new Error(message.error_text ?? 'tool call failed')); + return; + } + }); + + return (name, input) => { + const id = 'msg-' + ++nextId; + return new Promise((resolve, reject) => { + pending.set(id, { resolve, reject }); + parentPort.postMessage({ + type: 'tool_call', + id, + name: String(name), + input, + }); + }); + }; + } + + function createContentItems() { + const contentItems = []; + const push = contentItems.push.bind(contentItems); + contentItems.push = (...items) => { + for (const item of items) { + parentPort.postMessage({ + type: 'content_item', + item: cloneJsonValue(item), + }); + } + return push(...items); + }; + parentPort.on('message', (message) => { + if (message.type === 'clear_content') { + contentItems.splice(0, contentItems.length); + } + }); + return contentItems; + } + + function createToolsNamespace(callTool, enabledTools) { + const tools = Object.create(null); + + for (const { tool_name } of enabledTools) { + Object.defineProperty(tools, tool_name, { + value: async (args) => callTool(tool_name, args), + configurable: false, + enumerable: true, + writable: false, + }); + } + + return Object.freeze(tools); + } + + function createAllToolsMetadata(enabledTools) { + return Object.freeze( + enabledTools.map(({ module: modulePath, name, description }) => + Object.freeze({ + module: modulePath, + name, + description, + }) + ) + ); + } + + function createToolsModule(context, callTool, enabledTools) { + const tools = createToolsNamespace(callTool, enabledTools); + const allTools = createAllToolsMetadata(enabledTools); + const exportNames = ['ALL_TOOLS']; + + for (const { tool_name } of enabledTools) { + if (tool_name !== 'ALL_TOOLS') { + exportNames.push(tool_name); + } + } + + const uniqueExportNames = [...new Set(exportNames)]; + + return new SyntheticModule( + uniqueExportNames, + function initToolsModule() { + this.setExport('ALL_TOOLS', allTools); + for (const exportName of uniqueExportNames) { + if (exportName !== 'ALL_TOOLS') { + this.setExport(exportName, tools[exportName]); + } + } + }, + { context } + ); + } + + function ensureContentItems(context) { + if (!Array.isArray(context.__codexContentItems)) { + context.__codexContentItems = []; + } + return context.__codexContentItems; + } + + function serializeOutputText(value) { + if (typeof value === 'string') { + return value; + } + if ( + typeof value === 'undefined' || + value === null || + typeof value === 'boolean' || + typeof value === 'number' || + typeof value === 'bigint' + ) { + return String(value); + } + + const serialized = JSON.stringify(value); + if (typeof serialized === 'string') { + return serialized; + } + + return String(value); + } + + function normalizeOutputImageUrl(value) { + if (typeof value !== 'string' || !value) { + throw new TypeError('output_image expects a non-empty image URL string'); + } + if (/^(?:https?:\/\/|data:)/i.test(value)) { + return value; + } + throw new TypeError('output_image expects an http(s) or data URL'); + } + + function createCodeModeModule(context, state) { + const load = (key) => { + if (typeof key !== 'string') { + throw new TypeError('load key must be a string'); + } + if (!Object.prototype.hasOwnProperty.call(state.storedValues, key)) { + return undefined; + } + return cloneJsonValue(state.storedValues[key]); + }; + const store = (key, value) => { + if (typeof key !== 'string') { + throw new TypeError('store key must be a string'); + } + state.storedValues[key] = cloneJsonValue(value); + }; + const outputText = (value) => { + const item = { + type: 'input_text', + text: serializeOutputText(value), + }; + ensureContentItems(context).push(item); + return item; + }; + const outputImage = (value) => { + const item = { + type: 'input_image', + image_url: normalizeOutputImageUrl(value), + }; + ensureContentItems(context).push(item); + return item; + }; + + return new SyntheticModule( + [ + 'load', + 'output_text', + 'output_image', + 'set_max_output_tokens_per_exec_call', + 'set_yield_time', + 'store', + ], + function initCodeModeModule() { + this.setExport('load', load); + this.setExport('output_text', outputText); + this.setExport('output_image', outputImage); + this.setExport('set_max_output_tokens_per_exec_call', (value) => { + const normalized = normalizeMaxOutputTokensPerExecCall(value); + state.maxOutputTokensPerExecCall = normalized; + parentPort.postMessage({ + type: 'set_max_output_tokens_per_exec_call', + value: normalized, + }); + return normalized; + }); + this.setExport('set_yield_time', (value) => { + const normalized = normalizeYieldTime(value); + parentPort.postMessage({ + type: 'set_yield_time', + value: normalized, + }); + return normalized; + }); + this.setExport('store', store); + }, + { context } + ); + } + + function namespacesMatch(left, right) { + if (left.length !== right.length) { + return false; + } + return left.every((segment, index) => segment === right[index]); + } + + function createNamespacedToolsNamespace(callTool, enabledTools, namespace) { + const tools = Object.create(null); + + for (const tool of enabledTools) { + const toolNamespace = Array.isArray(tool.namespace) ? tool.namespace : []; + if (!namespacesMatch(toolNamespace, namespace)) { + continue; + } + + Object.defineProperty(tools, tool.name, { + value: async (args) => callTool(tool.tool_name, args), + configurable: false, + enumerable: true, + writable: false, + }); + } + + return Object.freeze(tools); + } + + function createNamespacedToolsModule(context, callTool, enabledTools, namespace) { + const tools = createNamespacedToolsNamespace(callTool, enabledTools, namespace); + const exportNames = []; + + for (const exportName of Object.keys(tools)) { + if (exportName !== 'ALL_TOOLS') { + exportNames.push(exportName); + } + } + + const uniqueExportNames = [...new Set(exportNames)]; + + return new SyntheticModule( + uniqueExportNames, + function initNamespacedToolsModule() { + for (const exportName of uniqueExportNames) { + this.setExport(exportName, tools[exportName]); + } + }, + { context } + ); + } + + function createModuleResolver(context, callTool, enabledTools, state) { + const toolsModule = createToolsModule(context, callTool, enabledTools); + const codeModeModule = createCodeModeModule(context, state); + const namespacedModules = new Map(); + + return function resolveModule(specifier) { + if (specifier === 'tools.js') { + return toolsModule; + } + if (specifier === '@openai/code_mode' || specifier === 'openai/code_mode') { + return codeModeModule; + } + const namespacedMatch = /^tools\/(.+)\.js$/.exec(specifier); + if (!namespacedMatch) { + throw new Error('Unsupported import in exec: ' + specifier); + } + + const namespace = namespacedMatch[1] + .split('/') + .filter((segment) => segment.length > 0); + if (namespace.length === 0) { + throw new Error('Unsupported import in exec: ' + specifier); + } + + const cacheKey = namespace.join('/'); + if (!namespacedModules.has(cacheKey)) { + namespacedModules.set( + cacheKey, + createNamespacedToolsModule(context, callTool, enabledTools, namespace) + ); + } + return namespacedModules.get(cacheKey); + }; + } + + async function runModule(context, start, state, callTool) { + const resolveModule = createModuleResolver( + context, + callTool, + start.enabled_tools ?? [], + state + ); + const mainModule = new SourceTextModule(start.source, { + context, + identifier: 'exec_main.mjs', + importModuleDynamically: async (specifier) => resolveModule(specifier), + }); + + await mainModule.link(resolveModule); + await mainModule.evaluate(); + } + + async function main() { + const start = workerData ?? {}; + const state = { + maxOutputTokensPerExecCall: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + storedValues: cloneJsonValue(start.stored_values ?? {}), + }; + const callTool = createToolCaller(); + const context = vm.createContext({ + __codexContentItems: createContentItems(), + __codex_tool_call: callTool, + }); + + try { + await runModule(context, start, state, callTool); + parentPort.postMessage({ + type: 'result', + stored_values: state.storedValues, + }); + } catch (error) { + parentPort.postMessage({ + type: 'result', + stored_values: state.storedValues, + error_text: formatErrorText(error), + }); + } + } + + void main().catch((error) => { + parentPort.postMessage({ + type: 'result', + stored_values: {}, + error_text: formatErrorText(error), + }); + }); +} + function createProtocol() { const rl = readline.createInterface({ input: process.stdin, @@ -36,28 +454,42 @@ function createProtocol() { try { message = JSON.parse(line); } catch (error) { - process.stderr.write(`${formatErrorText(error)}\n`); + process.stderr.write(formatErrorText(error) + '\n'); return; } if (message.type === 'start') { - const session = { id: String(message.session_id) }; - sessions.set(session.id, session); - void processSession(protocol, sessions, session, message); + startSession(protocol, sessions, message); + return; + } + + if (message.type === 'poll') { + const session = sessions.get(message.session_id); + if (session) { + schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0)); + } + return; + } + + if (message.type === 'terminate') { + const session = sessions.get(message.session_id); + if (session) { + void terminateSession(protocol, sessions, session); + } return; } if (message.type === 'response') { - const entry = pending.get(`${message.session_id}:${message.id}`); + const entry = pending.get(message.session_id + ':' + message.id); if (!entry) { return; } - pending.delete(`${message.session_id}:${message.id}`); + pending.delete(message.session_id + ':' + message.id); entry.resolve(message.code_mode_result ?? ''); return; } - process.stderr.write(`Unknown protocol message type: ${message.type}\n`); + process.stderr.write('Unknown protocol message type: ' + message.type + '\n'); }); rl.on('close', () => { @@ -66,13 +498,18 @@ function createProtocol() { entry.reject(error); } pending.clear(); + for (const session of sessions.values()) { + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + void session.worker.terminate().catch(() => {}); + } sessions.clear(); closedResolve(); }); function send(message) { return new Promise((resolve, reject) => { - process.stdout.write(`${JSON.stringify(message)}\n`, (error) => { + process.stdout.write(JSON.stringify(message) + '\n', (error) => { if (error) { reject(error); } else { @@ -83,8 +520,8 @@ function createProtocol() { } function request(sessionId, type, payload) { - const id = `msg-${++nextId}`; - const pendingKey = `${sessionId}:${id}`; + const id = 'msg-' + ++nextId; + const pendingKey = sessionId + ':' + id; return new Promise((resolve, reject) => { pending.set(pendingKey, { resolve, reject }); void send({ type, session_id: sessionId, id, ...payload }).catch((error) => { @@ -98,309 +535,199 @@ function createProtocol() { return protocol; } -function readContentItems(context) { - try { - const serialized = vm.runInContext('JSON.stringify(globalThis.__codexContentItems ?? [])', context); - const contentItems = JSON.parse(serialized); - return Array.isArray(contentItems) ? contentItems : []; - } catch { - return []; - } -} - -function formatErrorText(error) { - return String(error && error.stack ? error.stack : error); -} - -function cloneJsonValue(value) { - return JSON.parse(JSON.stringify(value)); +function sessionWorkerSource() { + return '(' + codeModeWorkerMain.toString() + ')();'; } -function createToolCaller(protocol, sessionId) { - return (name, input) => - protocol.request(sessionId, 'tool_call', { - name: String(name), - input, +function startSession(protocol, sessions, start) { + const session = { + completed: false, + content_items: [], + id: start.session_id, + initial_yield_timer: null, + initial_yield_triggered: false, + max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + poll_yield_timer: null, + worker: new Worker(sessionWorkerSource(), { + eval: true, + workerData: start, + }), + }; + sessions.set(session.id, session); + + session.worker.on('message', (message) => { + void handleWorkerMessage(protocol, sessions, session, message).catch((error) => { + void completeSession(protocol, sessions, session, { + type: 'result', + stored_values: {}, + error_text: formatErrorText(error), + }); }); -} - -function createToolsNamespace(callTool, enabledTools) { - const tools = Object.create(null); - - for (const { tool_name } of enabledTools) { - Object.defineProperty(tools, tool_name, { - value: async (args) => callTool(tool_name, args), - configurable: false, - enumerable: true, - writable: false, + }); + session.worker.on('error', (error) => { + void completeSession(protocol, sessions, session, { + type: 'result', + stored_values: {}, + error_text: formatErrorText(error), }); - } - - return Object.freeze(tools); -} - -function createAllToolsMetadata(enabledTools) { - return Object.freeze( - enabledTools.map(({ module: modulePath, name, description }) => - Object.freeze({ - module: modulePath, - name, - description, - }) - ) - ); -} - -function createToolsModule(context, callTool, enabledTools) { - const tools = createToolsNamespace(callTool, enabledTools); - const allTools = createAllToolsMetadata(enabledTools); - const exportNames = ['ALL_TOOLS']; - - for (const { tool_name } of enabledTools) { - if (tool_name !== 'ALL_TOOLS') { - exportNames.push(tool_name); + }); + session.worker.on('exit', (code) => { + if (code !== 0 && !session.completed) { + void completeSession(protocol, sessions, session, { + type: 'result', + stored_values: {}, + error_text: 'exec worker exited with code ' + code, + }); } - } - - const uniqueExportNames = [...new Set(exportNames)]; - - return new SyntheticModule( - uniqueExportNames, - function initToolsModule() { - this.setExport('ALL_TOOLS', allTools); - for (const exportName of uniqueExportNames) { - if (exportName !== 'ALL_TOOLS') { - this.setExport(exportName, tools[exportName]); - } - } - }, - { context } - ); + }); } -function ensureContentItems(context) { - if (!Array.isArray(context.__codexContentItems)) { - context.__codexContentItems = []; +async function handleWorkerMessage(protocol, sessions, session, message) { + if (session.completed) { + return; } - return context.__codexContentItems; -} -function serializeOutputText(value) { - if (typeof value === 'string') { - return value; - } - if ( - typeof value === 'undefined' || - value === null || - typeof value === 'boolean' || - typeof value === 'number' || - typeof value === 'bigint' - ) { - return String(value); + if (message.type === 'content_item') { + session.content_items.push(cloneJsonValue(message.item)); + return; } - const serialized = JSON.stringify(value); - if (typeof serialized === 'string') { - return serialized; + if (message.type === 'set_yield_time') { + scheduleInitialYield(protocol, session, normalizeYieldTime(message.value ?? 0)); + return; } - return String(value); -} + if (message.type === 'set_max_output_tokens_per_exec_call') { + session.max_output_tokens_per_exec_call = normalizeMaxOutputTokensPerExecCall(message.value); + return; + } -function normalizeOutputImageUrl(value) { - if (typeof value !== 'string' || !value) { - throw new TypeError('output_image expects a non-empty image URL string'); + if (message.type === 'tool_call') { + void forwardToolCall(protocol, session, message); + return; } - if (/^(?:https?:\/\/|data:)/i.test(value)) { - return value; + + if (message.type === 'result') { + await completeSession(protocol, sessions, session, { + type: 'result', + stored_values: cloneJsonValue(message.stored_values ?? {}), + error_text: + typeof message.error_text === 'string' ? message.error_text : undefined, + }); + return; } - throw new TypeError('output_image expects an http(s) or data URL'); + + process.stderr.write('Unknown worker message type: ' + message.type + '\n'); } -function createCodeModeModule(context, state) { - const load = (key) => { - if (typeof key !== 'string') { - throw new TypeError('load key must be a string'); - } - if (!Object.prototype.hasOwnProperty.call(state.storedValues, key)) { - return undefined; +async function forwardToolCall(protocol, session, message) { + try { + const result = await protocol.request(session.id, 'tool_call', { + name: String(message.name), + input: message.input, + }); + if (session.completed) { + return; } - return cloneJsonValue(state.storedValues[key]); - }; - const store = (key, value) => { - if (typeof key !== 'string') { - throw new TypeError('store key must be a string'); + try { + session.worker.postMessage({ + type: 'tool_response', + id: message.id, + result, + }); + } catch {} + } catch (error) { + if (session.completed) { + return; } - state.storedValues[key] = cloneJsonValue(value); - }; - const outputText = (value) => { - const item = { - type: 'input_text', - text: serializeOutputText(value), - }; - ensureContentItems(context).push(item); - return item; - }; - const outputImage = (value) => { - const item = { - type: 'input_image', - image_url: normalizeOutputImageUrl(value), - }; - ensureContentItems(context).push(item); - return item; - }; - - return new SyntheticModule( - ['load', 'output_text', 'output_image', 'set_max_output_tokens_per_exec_call', 'store'], - function initCodeModeModule() { - this.setExport('load', load); - this.setExport('output_text', outputText); - this.setExport('output_image', outputImage); - this.setExport('set_max_output_tokens_per_exec_call', (value) => { - const normalized = normalizeMaxOutputTokensPerExecCall(value); - state.maxOutputTokensPerExecCall = normalized; - return normalized; + try { + session.worker.postMessage({ + type: 'tool_response_error', + id: message.id, + error_text: formatErrorText(error), }); - this.setExport('store', store); - }, - { context } - ); -} - -function namespacesMatch(left, right) { - if (left.length !== right.length) { - return false; + } catch {} } - return left.every((segment, index) => segment === right[index]); } -function createNamespacedToolsNamespace(callTool, enabledTools, namespace) { - const tools = Object.create(null); - - for (const tool of enabledTools) { - const toolNamespace = Array.isArray(tool.namespace) ? tool.namespace : []; - if (!namespacesMatch(toolNamespace, namespace)) { - continue; - } - - Object.defineProperty(tools, tool.name, { - value: async (args) => callTool(tool.tool_name, args), - configurable: false, - enumerable: true, - writable: false, - }); +async function sendYielded(protocol, session) { + if (session.completed) { + return; } - - return Object.freeze(tools); + const contentItems = takeContentItems(session); + try { + session.worker.postMessage({ type: 'clear_content' }); + } catch {} + await protocol.send({ + type: 'yielded', + session_id: session.id, + content_items: contentItems, + max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, + }); } -function createNamespacedToolsModule(context, callTool, enabledTools, namespace) { - const tools = createNamespacedToolsNamespace(callTool, enabledTools, namespace); - const exportNames = []; - - for (const exportName of Object.keys(tools)) { - if (exportName !== 'ALL_TOOLS') { - exportNames.push(exportName); - } +function scheduleInitialYield(protocol, session, yieldTime) { + if (session.completed || session.initial_yield_triggered) { + return yieldTime; } - - const uniqueExportNames = [...new Set(exportNames)]; - - return new SyntheticModule( - uniqueExportNames, - function initNamespacedToolsModule() { - for (const exportName of uniqueExportNames) { - this.setExport(exportName, tools[exportName]); - } - }, - { context } - ); + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.initial_yield_timer = setTimeout(() => { + session.initial_yield_timer = null; + session.initial_yield_triggered = true; + void sendYielded(protocol, session); + }, yieldTime); + return yieldTime; } -function createModuleResolver(context, callTool, enabledTools, state) { - const toolsModule = createToolsModule(context, callTool, enabledTools); - const codeModeModule = createCodeModeModule(context, state); - const namespacedModules = new Map(); - - return function resolveModule(specifier) { - if (specifier === 'tools.js') { - return toolsModule; - } - if (specifier === '@openai/code_mode' || specifier === 'openai/code_mode') { - return codeModeModule; - } - const namespacedMatch = /^tools\/(.+)\.js$/.exec(specifier); - if (!namespacedMatch) { - throw new Error(`Unsupported import in exec: ${specifier}`); - } - - const namespace = namespacedMatch[1] - .split('/') - .filter((segment) => segment.length > 0); - if (namespace.length === 0) { - throw new Error(`Unsupported import in exec: ${specifier}`); - } - - const cacheKey = namespace.join('/'); - if (!namespacedModules.has(cacheKey)) { - namespacedModules.set( - cacheKey, - createNamespacedToolsModule(context, callTool, enabledTools, namespace) - ); - } - return namespacedModules.get(cacheKey); - }; +function schedulePollYield(protocol, session, yieldTime) { + if (session.completed) { + return; + } + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + session.poll_yield_timer = setTimeout(() => { + session.poll_yield_timer = null; + void sendYielded(protocol, session); + }, yieldTime); } -async function runModule(context, start, state, callTool) { - const resolveModule = createModuleResolver( - context, - callTool, - start.enabled_tools ?? [], - state - ); - const mainModule = new SourceTextModule(start.source, { - context, - identifier: 'exec_main.mjs', - importModuleDynamically: async (specifier) => resolveModule(specifier), +async function completeSession(protocol, sessions, session, message) { + if (session.completed) { + return; + } + session.completed = true; + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + sessions.delete(session.id); + const contentItems = takeContentItems(session); + try { + session.worker.postMessage({ type: 'clear_content' }); + } catch {} + await protocol.send({ + ...message, + session_id: session.id, + content_items: contentItems, + max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, }); - - await mainModule.link(resolveModule); - await mainModule.evaluate(); } -async function processSession(protocol, sessions, session, start) { - const state = { - maxOutputTokensPerExecCall: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, - storedValues: cloneJsonValue(start.stored_values ?? {}), - }; - const callTool = createToolCaller(protocol, session.id); - const context = vm.createContext({ - __codexContentItems: [], - __codex_tool_call: callTool, - }); - - try { - await runModule(context, start, state, callTool); - await protocol.send({ - type: 'result', - session_id: session.id, - content_items: readContentItems(context), - stored_values: state.storedValues, - max_output_tokens_per_exec_call: state.maxOutputTokensPerExecCall, - }); - } catch (error) { - await protocol.send({ - type: 'result', - session_id: session.id, - content_items: readContentItems(context), - stored_values: state.storedValues, - error_text: formatErrorText(error), - max_output_tokens_per_exec_call: state.maxOutputTokensPerExecCall, - }); - } finally { - sessions.delete(session.id); +async function terminateSession(protocol, sessions, session) { + if (session.completed) { + return; } + session.completed = true; + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + sessions.delete(session.id); + const contentItems = takeContentItems(session); + try { + await session.worker.terminate(); + } catch {} + await protocol.send({ + type: 'terminated', + session_id: session.id, + content_items: contentItems, + max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, + }); } async function main() { @@ -410,7 +737,7 @@ async function main() { void main().catch(async (error) => { try { - process.stderr.write(`${formatErrorText(error)}\n`); + process.stderr.write(formatErrorText(error) + '\n'); } finally { process.exitCode = 1; } diff --git a/codex-rs/core/src/tools/handlers/code_mode.rs b/codex-rs/core/src/tools/handlers/code_mode.rs index 4763a69b46f..fe4a23965dd 100644 --- a/codex-rs/core/src/tools/handlers/code_mode.rs +++ b/codex-rs/core/src/tools/handlers/code_mode.rs @@ -1,16 +1,35 @@ use async_trait::async_trait; +use serde::Deserialize; -use crate::features::Feature; use crate::function_tool::FunctionCallError; use crate::tools::code_mode; +use crate::tools::code_mode::DEFAULT_WAIT_YIELD_TIME_MS; use crate::tools::code_mode::PUBLIC_TOOL_NAME; +use crate::tools::code_mode::WAIT_TOOL_NAME; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::handlers::parse_arguments; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; pub struct CodeModeHandler; +pub struct CodeModeWaitHandler; + +#[derive(Debug, Deserialize)] +struct ExecWaitArgs { + session_id: i32, + #[serde(default = "default_wait_yield_time_ms")] + yield_time_ms: u64, + #[serde(default)] + max_tokens: Option, + #[serde(default)] + terminate: bool, +} + +fn default_wait_yield_time_ms() -> u64 { + DEFAULT_WAIT_YIELD_TIME_MS +} #[async_trait] impl ToolHandler for CodeModeHandler { @@ -29,25 +48,57 @@ impl ToolHandler for CodeModeHandler { session, turn, tracker, + tool_name, payload, .. } = invocation; - if !session.features().enabled(Feature::CodeMode) { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} is disabled by feature flag" - ))); + match payload { + ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => { + code_mode::execute(session, turn, tracker, input).await + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" + ))), } + } +} - let code = match payload { - ToolPayload::Custom { input } => input, - _ => { - return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" - ))); - } - }; +#[async_trait] +impl ToolHandler for CodeModeWaitHandler { + type Output = FunctionToolOutput; - code_mode::execute(session, turn, tracker, code).await + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { + session, + turn, + tracker, + tool_name, + payload, + .. + } = invocation; + + match payload { + ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { + let args: ExecWaitArgs = parse_arguments(&arguments)?; + code_mode::wait( + session, + turn, + tracker, + args.session_id, + args.yield_time_ms, + args.max_tokens, + args.terminate, + ) + .await + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{WAIT_TOOL_NAME} expects JSON arguments" + ))), + } } } diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index 38d0f74f4ca..def280f03bb 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -34,6 +34,7 @@ use crate::sandboxing::normalize_additional_permissions; pub use apply_patch::ApplyPatchHandler; pub use artifacts::ArtifactsHandler; pub use code_mode::CodeModeHandler; +pub use code_mode::CodeModeWaitHandler; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::AskForApproval; pub use dynamic::DynamicToolHandler; diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index cf9eaa32f75..18528ba9f96 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -7,7 +7,9 @@ use crate::features::Feature; use crate::features::Features; use crate::mcp_connection_manager::ToolInfo; use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig; +use crate::tools::code_mode::DEFAULT_WAIT_YIELD_TIME_MS; use crate::tools::code_mode::PUBLIC_TOOL_NAME; +use crate::tools::code_mode::WAIT_TOOL_NAME; use crate::tools::code_mode_description::augment_tool_spec_for_code_mode; use crate::tools::handlers::PLAN_TOOL; use crate::tools::handlers::SEARCH_TOOL_BM25_DEFAULT_LIMIT; @@ -572,6 +574,54 @@ fn create_write_stdin_tool() -> ToolSpec { }) } +fn create_exec_wait_tool() -> ToolSpec { + let properties = BTreeMap::from([ + ( + "session_id".to_string(), + JsonSchema::Number { + description: Some("Identifier of the running exec session.".to_string()), + }, + ), + ( + "yield_time_ms".to_string(), + JsonSchema::Number { + description: Some( + "How long to wait (in milliseconds) for more output before yielding again." + .to_string(), + ), + }, + ), + ( + "max_tokens".to_string(), + JsonSchema::Number { + description: Some( + "Maximum number of output tokens to return for this wait call.".to_string(), + ), + }, + ), + ( + "terminate".to_string(), + JsonSchema::Boolean { + description: Some("Whether to terminate the running exec session.".to_string()), + }, + ), + ]); + + ToolSpec::Function(ResponsesApiTool { + name: WAIT_TOOL_NAME.to_string(), + description: format!( + "Waits on a yielded `{PUBLIC_TOOL_NAME}` session and returns new output or completion." + ), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["session_id".to_string()]), + additional_properties: Some(false.into()), + }, + output_schema: None, + }) +} + fn create_shell_tool(request_permission_enabled: bool) -> ToolSpec { let mut properties = BTreeMap::from([ ( @@ -1660,7 +1710,7 @@ source: /[\s\S]+/ enabled_tool_names.join(", ") }; let description = format!( - "Runs JavaScript in a Node-backed `node:vm` context. This is a freeform tool: send raw JavaScript source text (no JSON/quotes/markdown fences). Direct tool calls remain available while `{PUBLIC_TOOL_NAME}` is enabled. Inside JavaScript, import nested tools from `tools.js`, for example `import {{ exec_command }} from \"tools.js\"` or `import {{ ALL_TOOLS }} from \"tools.js\"` to inspect the available `{{ module, name, description }}` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import {{ append_notebook_logs_chart }} from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values. Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, store, load }}` from `\"@openai/code_mode\"` (or `\"openai/code_mode\"`); `output_text(value)` surfaces text back to the model and stringifies non-string objects when possible, `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs, `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, `load(key)` returns a cloned stored value or `undefined`, and `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate the final Rust-side result of the current `{PUBLIC_TOOL_NAME}` execution. The default is `10000`. This guards the overall `{PUBLIC_TOOL_NAME}` output, not individual nested tool invocations. The returned content starts with a separate `Script completed` or `Script failed` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker. Function tools require JSON object arguments. Freeform tools require raw strings. `add_content(value)` remains available for compatibility with a content item, content-item array, or string. Structured nested-tool results should be converted to text first, for example with `JSON.stringify(...)`. Only content passed to `output_text(...)`, `output_image(...)`, or `add_content(value)` is surfaced back to the model. Enabled nested tools: {enabled_list}." + "Runs JavaScript in a Node-backed `node:vm` context. This is a freeform tool: send raw JavaScript source text (no JSON/quotes/markdown fences). Direct tool calls remain available while `{PUBLIC_TOOL_NAME}` is enabled. Inside JavaScript, import nested tools from `tools.js`, for example `import {{ exec_command }} from \"tools.js\"` or `import {{ ALL_TOOLS }} from \"tools.js\"` to inspect the available `{{ module, name, description }}` entries. Namespaced tools are also available from `tools/.js`; MCP tools use `tools/mcp/.js`, for example `import {{ append_notebook_logs_chart }} from \"tools/mcp/ologs.js\"`. Nested tool calls resolve to their code-mode result values. Import `{{ output_text, output_image, set_max_output_tokens_per_exec_call, set_yield_time, store, load }}` from `\"@openai/code_mode\"` (or `\"openai/code_mode\"`); `output_text(value)` surfaces text back to the model and stringifies non-string objects when possible, `output_image(imageUrl)` appends an `input_image` content item for `http(s)` or `data:` URLs, `store(key, value)` persists JSON-serializable values across `{PUBLIC_TOOL_NAME}` calls in the current session, `load(key)` returns a cloned stored value or `undefined`, `set_max_output_tokens_per_exec_call(value)` sets the token budget used to truncate direct `{PUBLIC_TOOL_NAME}` returns, and `{WAIT_TOOL_NAME}` uses its own `max_tokens` argument with a default of `10000`. `set_yield_time(value)` asks `{PUBLIC_TOOL_NAME}` to return early if the script is still running after that many milliseconds so `{WAIT_TOOL_NAME}` can resume it later. The default wait timeout for `{WAIT_TOOL_NAME}` is {DEFAULT_WAIT_YIELD_TIME_MS}. The returned content starts with a separate `Script completed`, `Script failed`, or `Script running with session ID …` text item that includes wall time. When truncation happens, the final text may include `Total output lines:` and the usual `…N tokens truncated…` marker. Function tools require JSON object arguments. Freeform tools require raw strings. `add_content(value)` remains available for compatibility with a content item, content-item array, or string. Structured nested-tool results should be converted to text first, for example with `JSON.stringify(...)`. Only content passed to `output_text(...)`, `output_image(...)`, or `add_content(value)` is surfaced back to the model. Enabled nested tools: {enabled_list}." ); ToolSpec::Freeform(FreeformTool { @@ -1675,7 +1725,9 @@ source: /[\s\S]+/ } fn is_code_mode_nested_tool(spec: &ToolSpec) -> bool { - spec.name() != PUBLIC_TOOL_NAME && matches!(spec, ToolSpec::Function(_) | ToolSpec::Freeform(_)) + spec.name() != PUBLIC_TOOL_NAME + && spec.name() != WAIT_TOOL_NAME + && matches!(spec, ToolSpec::Function(_) | ToolSpec::Freeform(_)) } fn create_list_mcp_resources_tool() -> ToolSpec { @@ -2030,6 +2082,7 @@ pub(crate) fn build_specs( use crate::tools::handlers::ApplyPatchHandler; use crate::tools::handlers::ArtifactsHandler; use crate::tools::handlers::CodeModeHandler; + use crate::tools::handlers::CodeModeWaitHandler; use crate::tools::handlers::DynamicToolHandler; use crate::tools::handlers::GrepFilesHandler; use crate::tools::handlers::JsReplHandler; @@ -2067,6 +2120,7 @@ pub(crate) fn build_specs( }); let search_tool_handler = Arc::new(SearchToolBm25Handler); let code_mode_handler = Arc::new(CodeModeHandler); + let code_mode_wait_handler = Arc::new(CodeModeWaitHandler); let js_repl_handler = Arc::new(JsReplHandler); let js_repl_reset_handler = Arc::new(JsReplResetHandler); let artifacts_handler = Arc::new(ArtifactsHandler); @@ -2096,6 +2150,13 @@ pub(crate) fn build_specs( config.code_mode_enabled, ); builder.register_handler(PUBLIC_TOOL_NAME, code_mode_handler); + push_tool_spec( + &mut builder, + create_exec_wait_tool(), + false, + config.code_mode_enabled, + ); + builder.register_handler(WAIT_TOOL_NAME, code_mode_wait_handler); } match &config.shell_type { diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index 07cadc3431e..3112ae1dce5 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -21,6 +21,7 @@ use pretty_assertions::assert_eq; use serde_json::Value; use std::collections::HashMap; use std::fs; +use std::path::Path; use std::time::Duration; use wiremock::MockServer; @@ -32,6 +33,16 @@ fn custom_tool_output_items(req: &ResponsesRequest, call_id: &str) -> Vec .clone() } +fn function_tool_output_items(req: &ResponsesRequest, call_id: &str) -> Vec { + match req.function_call_output(call_id).get("output") { + Some(Value::Array(items)) => items.clone(), + Some(Value::String(text)) => { + vec![serde_json::json!({ "type": "input_text", "text": text })] + } + _ => panic!("function tool output should be serialized as text or content items"), + } +} + fn text_item(items: &[Value], index: usize) -> &str { items[index] .get("text") @@ -39,6 +50,20 @@ fn text_item(items: &[Value], index: usize) -> &str { .expect("content item should be input_text") } +fn extract_running_session_id(text: &str) -> i32 { + text.strip_prefix("Script running with session ID ") + .and_then(|rest| rest.split('\n').next()) + .expect("running header should contain a session ID") + .parse() + .expect("session ID should parse as i32") +} + +fn wait_for_file_source(path: &Path) -> Result { + let quoted_path = shlex::try_join([path.to_string_lossy().as_ref()])?; + let command = format!("if [ -f {quoted_path} ]; then printf ready; fi"); + Ok(format!("await waitForFile({command:?});")) +} + fn custom_tool_output_body_and_success( req: &ResponsesRequest, call_id: &str, @@ -289,6 +314,775 @@ Error:\ boom\n Ok(()) } +#[cfg_attr(windows, ignore = "no exec_command on Windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_can_yield_and_resume_with_exec_wait() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let mut builder = test_codex().with_config(move |config| { + let _ = config.features.enable(Feature::CodeMode); + }); + let test = builder.build(&server).await?; + let phase_2_gate = test.workspace_path("code-mode-phase-2.ready"); + let phase_3_gate = test.workspace_path("code-mode-phase-3.ready"); + let phase_2_wait = wait_for_file_source(&phase_2_gate)?; + let phase_3_wait = wait_for_file_source(&phase_3_gate)?; + + let code = format!( + r#" +import {{ output_text, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("phase 1"); +set_yield_time(10); +{phase_2_wait} +output_text("phase 2"); +{phase_3_wait} +output_text("phase 3"); +"# + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_custom_tool_call("call-1", "exec", &code), + ev_completed("resp-1"), + ]), + ) + .await; + let first_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "waiting"), + ev_completed("resp-2"), + ]), + ) + .await; + + test.submit_turn("start the long exec").await?; + + let first_request = first_completion.single_request(); + let first_items = custom_tool_output_items(&first_request, "call-1"); + assert_eq!(first_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script running with session ID \d+\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&first_items, 0), + ); + assert_eq!(text_item(&first_items, 1), "phase 1"); + let session_id = extract_running_session_id(text_item(&first_items, 0)); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-3"), + responses::ev_function_call( + "call-2", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_id, + "yield_time_ms": 1_000, + }))?, + ), + ev_completed("resp-3"), + ]), + ) + .await; + let second_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-2", "still waiting"), + ev_completed("resp-4"), + ]), + ) + .await; + + fs::write(&phase_2_gate, "ready")?; + test.submit_turn("wait again").await?; + + let second_request = second_completion.single_request(); + let second_items = function_tool_output_items(&second_request, "call-2"); + assert_eq!(second_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script running with session ID \d+\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&second_items, 0), + ); + assert_eq!( + extract_running_session_id(text_item(&second_items, 0)), + session_id + ); + assert_eq!(text_item(&second_items, 1), "phase 2"); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-5"), + responses::ev_function_call( + "call-3", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_id, + "yield_time_ms": 1_000, + }))?, + ), + ev_completed("resp-5"), + ]), + ) + .await; + let third_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-3", "done"), + ev_completed("resp-6"), + ]), + ) + .await; + + fs::write(&phase_3_gate, "ready")?; + test.submit_turn("wait for completion").await?; + + let third_request = third_completion.single_request(); + let third_items = function_tool_output_items(&third_request, "call-3"); + assert_eq!(third_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&third_items, 0), + ); + assert_eq!(text_item(&third_items, 1), "phase 3"); + + Ok(()) +} + +#[cfg_attr(windows, ignore = "no exec_command on Windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_can_run_multiple_yielded_sessions() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let mut builder = test_codex().with_config(move |config| { + let _ = config.features.enable(Feature::CodeMode); + }); + let test = builder.build(&server).await?; + let session_a_gate = test.workspace_path("code-mode-session-a.ready"); + let session_b_gate = test.workspace_path("code-mode-session-b.ready"); + let session_a_wait = wait_for_file_source(&session_a_gate)?; + let session_b_wait = wait_for_file_source(&session_b_gate)?; + + let session_a_code = format!( + r#" +import {{ output_text, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("session a start"); +set_yield_time(10); +{session_a_wait} +output_text("session a done"); +"# + ); + let session_b_code = format!( + r#" +import {{ output_text, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("session b start"); +set_yield_time(10); +{session_b_wait} +output_text("session b done"); +"# + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_custom_tool_call("call-1", "exec", &session_a_code), + ev_completed("resp-1"), + ]), + ) + .await; + let first_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "session a waiting"), + ev_completed("resp-2"), + ]), + ) + .await; + + test.submit_turn("start session a").await?; + + let first_request = first_completion.single_request(); + let first_items = custom_tool_output_items(&first_request, "call-1"); + assert_eq!(first_items.len(), 2); + let session_a_id = extract_running_session_id(text_item(&first_items, 0)); + assert_eq!(text_item(&first_items, 1), "session a start"); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-3"), + ev_custom_tool_call("call-2", "exec", &session_b_code), + ev_completed("resp-3"), + ]), + ) + .await; + let second_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-2", "session b waiting"), + ev_completed("resp-4"), + ]), + ) + .await; + + test.submit_turn("start session b").await?; + + let second_request = second_completion.single_request(); + let second_items = custom_tool_output_items(&second_request, "call-2"); + assert_eq!(second_items.len(), 2); + let session_b_id = extract_running_session_id(text_item(&second_items, 0)); + assert_eq!(text_item(&second_items, 1), "session b start"); + assert_ne!(session_a_id, session_b_id); + + fs::write(&session_a_gate, "ready")?; + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-5"), + responses::ev_function_call( + "call-3", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_a_id, + "yield_time_ms": 1_000, + }))?, + ), + ev_completed("resp-5"), + ]), + ) + .await; + let third_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-3", "session a done"), + ev_completed("resp-6"), + ]), + ) + .await; + + test.submit_turn("wait session a").await?; + + let third_request = third_completion.single_request(); + let third_items = function_tool_output_items(&third_request, "call-3"); + assert_eq!(third_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&third_items, 0), + ); + assert_eq!(text_item(&third_items, 1), "session a done"); + + fs::write(&session_b_gate, "ready")?; + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-7"), + responses::ev_function_call( + "call-4", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_b_id, + "yield_time_ms": 1_000, + }))?, + ), + ev_completed("resp-7"), + ]), + ) + .await; + let fourth_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-4", "session b done"), + ev_completed("resp-8"), + ]), + ) + .await; + + test.submit_turn("wait session b").await?; + + let fourth_request = fourth_completion.single_request(); + let fourth_items = function_tool_output_items(&fourth_request, "call-4"); + assert_eq!(fourth_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&fourth_items, 0), + ); + assert_eq!(text_item(&fourth_items, 1), "session b done"); + + Ok(()) +} + +#[cfg_attr(windows, ignore = "no exec_command on Windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_exec_wait_can_terminate_and_continue() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let mut builder = test_codex().with_config(move |config| { + let _ = config.features.enable(Feature::CodeMode); + }); + let test = builder.build(&server).await?; + let termination_gate = test.workspace_path("code-mode-terminate.ready"); + let termination_wait = wait_for_file_source(&termination_gate)?; + + let code = format!( + r#" +import {{ output_text, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("phase 1"); +set_yield_time(10); +{termination_wait} +output_text("phase 2"); +"# + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_custom_tool_call("call-1", "exec", &code), + ev_completed("resp-1"), + ]), + ) + .await; + let first_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "waiting"), + ev_completed("resp-2"), + ]), + ) + .await; + + test.submit_turn("start the long exec").await?; + + let first_request = first_completion.single_request(); + let first_items = custom_tool_output_items(&first_request, "call-1"); + assert_eq!(first_items.len(), 2); + let session_id = extract_running_session_id(text_item(&first_items, 0)); + assert_eq!(text_item(&first_items, 1), "phase 1"); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-3"), + responses::ev_function_call( + "call-2", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_id, + "terminate": true, + }))?, + ), + ev_completed("resp-3"), + ]), + ) + .await; + let second_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-2", "terminated"), + ev_completed("resp-4"), + ]), + ) + .await; + + test.submit_turn("terminate it").await?; + + let second_request = second_completion.single_request(); + let second_items = function_tool_output_items(&second_request, "call-2"); + assert_eq!(second_items.len(), 1); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script terminated\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&second_items, 0), + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-5"), + ev_custom_tool_call( + "call-3", + "exec", + r#" +import { output_text } from "@openai/code_mode"; + +output_text("after terminate"); +"#, + ), + ev_completed("resp-5"), + ]), + ) + .await; + let third_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-3", "done"), + ev_completed("resp-6"), + ]), + ) + .await; + + test.submit_turn("run another exec").await?; + + let third_request = third_completion.single_request(); + let third_items = custom_tool_output_items(&third_request, "call-3"); + assert_eq!(third_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&third_items, 0), + ); + assert_eq!(text_item(&third_items, 1), "after terminate"); + + Ok(()) +} + +#[cfg_attr(windows, ignore = "no exec_command on Windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_exec_wait_terminate_returns_completed_session_if_it_finished_in_background() +-> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let mut builder = test_codex().with_config(move |config| { + let _ = config.features.enable(Feature::CodeMode); + }); + let test = builder.build(&server).await?; + let session_a_gate = test.workspace_path("code-mode-session-a-finished.ready"); + let session_b_gate = test.workspace_path("code-mode-session-b-blocked.ready"); + let session_a_wait = wait_for_file_source(&session_a_gate)?; + let session_b_wait = wait_for_file_source(&session_b_gate)?; + + let session_a_code = format!( + r#" +import {{ output_text, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("session a start"); +set_yield_time(10); +{session_a_wait} +output_text("session a done"); +"# + ); + let session_b_code = format!( + r#" +import {{ output_text, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("session b start"); +set_yield_time(10); +{session_b_wait} +output_text("session b done"); +"# + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_custom_tool_call("call-1", "exec", &session_a_code), + ev_completed("resp-1"), + ]), + ) + .await; + let first_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "session a waiting"), + ev_completed("resp-2"), + ]), + ) + .await; + + test.submit_turn("start session a").await?; + + let first_request = first_completion.single_request(); + let first_items = custom_tool_output_items(&first_request, "call-1"); + assert_eq!(first_items.len(), 2); + let session_a_id = extract_running_session_id(text_item(&first_items, 0)); + assert_eq!(text_item(&first_items, 1), "session a start"); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-3"), + ev_custom_tool_call("call-2", "exec", &session_b_code), + ev_completed("resp-3"), + ]), + ) + .await; + let second_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-2", "session b waiting"), + ev_completed("resp-4"), + ]), + ) + .await; + + test.submit_turn("start session b").await?; + + let second_request = second_completion.single_request(); + let second_items = custom_tool_output_items(&second_request, "call-2"); + assert_eq!(second_items.len(), 2); + let session_b_id = extract_running_session_id(text_item(&second_items, 0)); + assert_eq!(text_item(&second_items, 1), "session b start"); + + fs::write(&session_a_gate, "ready")?; + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-5"), + responses::ev_function_call( + "call-3", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_b_id, + "yield_time_ms": 1_000, + }))?, + ), + ev_completed("resp-5"), + ]), + ) + .await; + let third_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-3", "session b still waiting"), + ev_completed("resp-6"), + ]), + ) + .await; + + test.submit_turn("wait session b").await?; + + let third_request = third_completion.single_request(); + let third_items = function_tool_output_items(&third_request, "call-3"); + assert_eq!(third_items.len(), 1); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script running with session ID \d+\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&third_items, 0), + ); + assert_eq!( + extract_running_session_id(text_item(&third_items, 0)), + session_b_id + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-7"), + responses::ev_function_call( + "call-4", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_a_id, + "terminate": true, + }))?, + ), + ev_completed("resp-7"), + ]), + ) + .await; + let fourth_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-4", "session a already done"), + ev_completed("resp-8"), + ]), + ) + .await; + + test.submit_turn("terminate session a").await?; + + let fourth_request = fourth_completion.single_request(); + let fourth_items = function_tool_output_items(&fourth_request, "call-4"); + assert_eq!(fourth_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&fourth_items, 0), + ); + assert_eq!(text_item(&fourth_items, 1), "session a done"); + + Ok(()) +} + +#[cfg_attr(windows, ignore = "no exec_command on Windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_exec_wait_uses_its_own_max_tokens_budget() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let mut builder = test_codex().with_config(move |config| { + let _ = config.features.enable(Feature::CodeMode); + }); + let test = builder.build(&server).await?; + let completion_gate = test.workspace_path("code-mode-max-tokens.ready"); + let completion_wait = wait_for_file_source(&completion_gate)?; + + let code = format!( + r#" +import {{ output_text, set_max_output_tokens_per_exec_call, set_yield_time }} from "@openai/code_mode"; +import {{ exec_command }} from "tools.js"; + +const waitForFile = async (cmd) => {{ + while ((await exec_command({{ cmd }})).output !== "ready") {{ + }} +}}; + +output_text("phase 1"); +set_max_output_tokens_per_exec_call(100); +set_yield_time(10); +{completion_wait} +output_text("token one token two token three token four token five token six token seven"); +"# + ); + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + ev_custom_tool_call("call-1", "exec", &code), + ev_completed("resp-1"), + ]), + ) + .await; + let first_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "waiting"), + ev_completed("resp-2"), + ]), + ) + .await; + + test.submit_turn("start the long exec").await?; + + let first_request = first_completion.single_request(); + let first_items = custom_tool_output_items(&first_request, "call-1"); + assert_eq!(first_items.len(), 2); + assert_eq!(text_item(&first_items, 1), "phase 1"); + let session_id = extract_running_session_id(text_item(&first_items, 0)); + + fs::write(&completion_gate, "ready")?; + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-3"), + responses::ev_function_call( + "call-2", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": session_id, + "yield_time_ms": 1_000, + "max_tokens": 6, + }))?, + ), + ev_completed("resp-3"), + ]), + ) + .await; + let second_completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-2", "done"), + ev_completed("resp-4"), + ]), + ) + .await; + + test.submit_turn("wait for completion").await?; + + let second_request = second_completion.single_request(); + let second_items = function_tool_output_items(&second_request, "call-2"); + assert_eq!(second_items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&second_items, 0), + ); + let expected_pattern = r#"(?sx) +\A +Total\ output\ lines:\ 1\n +\n +.*…\d+\ tokens\ truncated….* +\z +"#; + assert_regex_match(expected_pattern, text_item(&second_items, 1)); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn code_mode_can_output_serialized_text_via_openai_code_mode_module() -> Result<()> { skip_if_no_network!(Ok(())); From b44eb3bb3509893d48c4aa7f029948d48c24ae10 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 17:39:47 -0700 Subject: [PATCH 3/7] Simplify code-mode shared session state --- codex-rs/core/src/state/service.rs | 34 ++++---- codex-rs/core/src/tools/code_mode.rs | 116 +++++++++------------------ 2 files changed, 52 insertions(+), 98 deletions(-) diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index cc4791ae23b..8621d7dc316 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::HashSet; use std::sync::Arc; use crate::AuthManager; @@ -16,7 +17,6 @@ use crate::plugins::PluginsManager; use crate::skills::SkillsManager; use crate::state_db::StateDbHandle; use crate::tools::code_mode::CodeModeProcess; -use crate::tools::code_mode::CodeModeYieldedSession; use crate::tools::network_approval::NetworkApprovalService; use crate::tools::runtimes::ExecveSessionApproval; use crate::tools::sandboxing::ApprovalStore; @@ -33,8 +33,8 @@ use tokio_util::sync::CancellationToken; pub(crate) struct CodeModeStoreService { stored_values: Mutex>, - process: Mutex>>>, - yielded_sessions: Mutex>, + process: Arc>>, + yielded_sessions: Mutex>, next_session_id: Mutex, } @@ -42,8 +42,8 @@ impl Default for CodeModeStoreService { fn default() -> Self { Self { stored_values: Mutex::new(HashMap::new()), - process: Mutex::new(None), - yielded_sessions: Mutex::new(HashMap::new()), + process: Arc::new(Mutex::new(None)), + yielded_sessions: Mutex::new(HashSet::new()), next_session_id: Mutex::new(1), } } @@ -58,12 +58,8 @@ impl CodeModeStoreService { *self.stored_values.lock().await = values; } - pub(crate) async fn store_process(&self, process: Arc>) { - *self.process.lock().await = Some(process); - } - - pub(crate) async fn process(&self) -> Option>> { - self.process.lock().await.clone() + pub(crate) fn process(&self) -> Arc>> { + self.process.clone() } pub(crate) async fn allocate_session_id(&self) -> i32 { @@ -73,18 +69,16 @@ impl CodeModeStoreService { session_id } - pub(crate) async fn store_yielded_session(&self, yielded_session: CodeModeYieldedSession) { + pub(crate) async fn store_yielded_session(&self, session_id: i32) { + self.yielded_sessions.lock().await.insert(session_id); + } + + pub(crate) async fn take_yielded_session(&self, session_id: i32) -> Option { self.yielded_sessions .lock() .await - .insert(yielded_session.session_id, yielded_session); - } - - pub(crate) async fn take_yielded_session( - &self, - session_id: i32, - ) -> Option { - self.yielded_sessions.lock().await.remove(&session_id) + .remove(&session_id) + .then_some(session_id) } } diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index 945310cb844..40ea237fed2 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -78,11 +78,6 @@ impl CodeModeProcess { } } -#[derive(Clone, Debug)] -pub(crate) struct CodeModeYieldedSession { - pub(crate) session_id: i32, -} - #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] enum CodeModeToolKind { @@ -176,7 +171,7 @@ enum CodeModeSessionProgress { Finished(FunctionToolOutput), Yielded { output: FunctionToolOutput, - yielded_session: CodeModeYieldedSession, + session_id: i32, }, } @@ -241,14 +236,26 @@ pub(crate) async fn execute( .code_mode_store .allocate_session_id() .await; - let process = ensure_shared_code_mode_process(&exec) - .await - .map_err(FunctionCallError::RespondToModel)?; let result = { - let mut process = process.lock().await; + let process_slot = exec.session.services.code_mode_store.process(); + let mut process_slot = process_slot.lock().await; + let needs_spawn = match process_slot.as_mut() { + Some(process) => !matches!(process.has_exited(), Ok(false)), + None => true, + }; + if needs_spawn { + *process_slot = Some( + spawn_code_mode_process(&exec) + .await + .map_err(FunctionCallError::RespondToModel)?, + ); + } + let process = process_slot.as_mut().ok_or_else(|| { + FunctionCallError::RespondToModel(format!("{PUBLIC_TOOL_NAME} runner failed to start")) + })?; drive_code_mode_session( &exec, - &mut process, + process, session_id, CodeModeSessionAction::Start { enabled_tools, @@ -258,14 +265,11 @@ pub(crate) async fn execute( ) .await }; - if let Ok(CodeModeSessionProgress::Yielded { - yielded_session, .. - }) = &result - { + if let Ok(CodeModeSessionProgress::Yielded { session_id, .. }) = &result { exec.session .services .code_mode_store - .store_yielded_session(yielded_session.clone()) + .store_yielded_session(*session_id) .await; } match result { @@ -289,7 +293,7 @@ pub(crate) async fn wait( turn, tracker, }; - let yielded_session = exec + let yielded_session_id = exec .session .services .code_mode_store @@ -301,13 +305,23 @@ pub(crate) async fn wait( )) })?; - let process = existing_shared_code_mode_process(&exec).await?; let result = { - let mut process = process.lock().await; + let process_slot = exec.session.services.code_mode_store.process(); + let mut process_slot = process_slot.lock().await; + let process = process_slot.as_mut().ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" + )) + })?; + if !matches!(process.has_exited(), Ok(false)) { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" + ))); + } drive_code_mode_session( &exec, - &mut process, - yielded_session.session_id, + process, + yielded_session_id, if terminate { CodeModeSessionAction::Terminate { max_output_tokens } } else { @@ -319,14 +333,11 @@ pub(crate) async fn wait( ) .await }; - if let Ok(CodeModeSessionProgress::Yielded { - yielded_session, .. - }) = &result - { + if let Ok(CodeModeSessionProgress::Yielded { session_id, .. }) = &result { exec.session .services .code_mode_store - .store_yielded_session(yielded_session.clone()) + .store_yielded_session(*session_id) .await; } @@ -384,57 +395,6 @@ async fn spawn_code_mode_process(exec: &ExecContext) -> Result Result>, String> { - if let Some(process) = exec.session.services.code_mode_store.process().await { - let is_running = { - let mut process_guard = process.lock().await; - matches!(process_guard.has_exited(), Ok(false)) - }; - if is_running { - return Ok(process); - } - } - - let process = Arc::new(tokio::sync::Mutex::new( - spawn_code_mode_process(exec).await?, - )); - exec.session - .services - .code_mode_store - .store_process(process.clone()) - .await; - Ok(process) -} - -async fn existing_shared_code_mode_process( - exec: &ExecContext, -) -> Result>, FunctionCallError> { - let process = exec - .session - .services - .code_mode_store - .process() - .await - .ok_or_else(|| { - FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" - )) - })?; - let is_running = { - let mut process_guard = process.lock().await; - matches!(process_guard.has_exited(), Ok(false)) - }; - if is_running { - Ok(process) - } else { - Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" - ))) - } -} - async fn drive_code_mode_session( exec: &ExecContext, process: &mut CodeModeProcess, @@ -639,7 +599,7 @@ async fn handle_node_message( ); Ok(Some(CodeModeSessionProgress::Yielded { output: FunctionToolOutput::from_content(delta_items, Some(true)), - yielded_session: CodeModeYieldedSession { session_id }, + session_id, })) } NodeToHostMessage::Terminated { From 97268c2a002a4b24b3006a69df7e7c461b4f5bbb Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 20:48:31 -0700 Subject: [PATCH 4/7] Refine code mode session handling --- codex-rs/core/src/codex.rs | 4 +- codex-rs/core/src/codex_tests.rs | 8 +- codex-rs/core/src/state/service.rs | 57 +-- codex-rs/core/src/tools/code_mode.rs | 452 ++++++++----------- codex-rs/core/src/tools/code_mode_runner.cjs | 20 +- codex-rs/core/tests/suite/code_mode.rs | 65 ++- 6 files changed, 268 insertions(+), 338 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 2747cd07027..9976e6165a1 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1648,7 +1648,9 @@ impl Session { config.features.enabled(Feature::RuntimeMetrics), Self::build_model_client_beta_features_header(config.as_ref()), ), - code_mode_store: Default::default(), + code_mode_service: crate::tools::code_mode::CodeModeService::new( + config.js_repl_node_path.clone(), + ), }; let js_repl = Arc::new(JsReplHandle::with_node_path( config.js_repl_node_path.clone(), diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index 0265d810b7c..38bc6e5be30 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -2162,7 +2162,9 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { config.features.enabled(Feature::RuntimeMetrics), Session::build_model_client_beta_features_header(config.as_ref()), ), - code_mode_store: Default::default(), + code_mode_service: crate::tools::code_mode::CodeModeService::new( + config.js_repl_node_path.clone(), + ), }; let js_repl = Arc::new(JsReplHandle::with_node_path( config.js_repl_node_path.clone(), @@ -2723,7 +2725,9 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( config.features.enabled(Feature::RuntimeMetrics), Session::build_model_client_beta_features_header(config.as_ref()), ), - code_mode_store: Default::default(), + code_mode_service: crate::tools::code_mode::CodeModeService::new( + config.js_repl_node_path.clone(), + ), }; let js_repl = Arc::new(JsReplHandle::with_node_path( config.js_repl_node_path.clone(), diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 8621d7dc316..851618c00e9 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use std::sync::Arc; use crate::AuthManager; @@ -16,7 +15,7 @@ use crate::models_manager::manager::ModelsManager; use crate::plugins::PluginsManager; use crate::skills::SkillsManager; use crate::state_db::StateDbHandle; -use crate::tools::code_mode::CodeModeProcess; +use crate::tools::code_mode::CodeModeService; use crate::tools::network_approval::NetworkApprovalService; use crate::tools::runtimes::ExecveSessionApproval; use crate::tools::sandboxing::ApprovalStore; @@ -24,64 +23,12 @@ use crate::unified_exec::UnifiedExecProcessManager; use codex_hooks::Hooks; use codex_otel::SessionTelemetry; use codex_utils_absolute_path::AbsolutePathBuf; -use serde_json::Value as JsonValue; use std::path::PathBuf; use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::sync::watch; use tokio_util::sync::CancellationToken; -pub(crate) struct CodeModeStoreService { - stored_values: Mutex>, - process: Arc>>, - yielded_sessions: Mutex>, - next_session_id: Mutex, -} - -impl Default for CodeModeStoreService { - fn default() -> Self { - Self { - stored_values: Mutex::new(HashMap::new()), - process: Arc::new(Mutex::new(None)), - yielded_sessions: Mutex::new(HashSet::new()), - next_session_id: Mutex::new(1), - } - } -} - -impl CodeModeStoreService { - pub(crate) async fn stored_values(&self) -> HashMap { - self.stored_values.lock().await.clone() - } - - pub(crate) async fn replace_stored_values(&self, values: HashMap) { - *self.stored_values.lock().await = values; - } - - pub(crate) fn process(&self) -> Arc>> { - self.process.clone() - } - - pub(crate) async fn allocate_session_id(&self) -> i32 { - let mut next_session_id = self.next_session_id.lock().await; - let session_id = *next_session_id; - *next_session_id = next_session_id.saturating_add(1); - session_id - } - - pub(crate) async fn store_yielded_session(&self, session_id: i32) { - self.yielded_sessions.lock().await.insert(session_id); - } - - pub(crate) async fn take_yielded_session(&self, session_id: i32) -> Option { - self.yielded_sessions - .lock() - .await - .remove(&session_id) - .then_some(session_id) - } -} - pub(crate) struct SessionServices { pub(crate) mcp_connection_manager: Arc>, pub(crate) mcp_startup_cancellation_token: Mutex, @@ -113,5 +60,5 @@ pub(crate) struct SessionServices { pub(crate) state_db: Option, /// Session-scoped model client shared across turns. pub(crate) model_client: ModelClient, - pub(crate) code_mode_store: CodeModeStoreService, + pub(crate) code_mode_service: CodeModeService, } diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index 9aea9d68cd7..9971aa31720 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::collections::VecDeque; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -7,7 +8,6 @@ use crate::client_common::tools::ToolSpec; use crate::codex::Session; use crate::codex::TurnContext; use crate::config::Config; -use crate::exec_env::create_env; use crate::features::Feature; use crate::function_tool::FunctionCallError; use crate::tools::ToolRouter; @@ -31,7 +31,9 @@ use tokio::io::AsyncBufReadExt; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; +use tokio::sync::Mutex; use tokio::task::JoinHandle; +use tracing::warn; const CODE_MODE_RUNNER_SOURCE: &str = include_str!("code_mode_runner.cjs"); const CODE_MODE_BRIDGE_SOURCE: &str = include_str!("code_mode_bridge.js"); @@ -50,11 +52,56 @@ pub(crate) struct CodeModeProcess { child: tokio::process::Child, stdin: tokio::process::ChildStdin, stdout_lines: tokio::io::Lines>, - stderr_task: Option>, + stderr_task: Option>, pending_messages: HashMap>, } impl CodeModeProcess { + async fn write(&mut self, message: &HostToNodeMessage) -> Result<(), std::io::Error> { + let line = serde_json::to_string(message).map_err(std::io::Error::other)?; + self.stdin.write_all(line.as_bytes()).await?; + self.stdin.write_all(b"\n").await?; + self.stdin.flush().await?; + Ok(()) + } + + async fn read(&mut self, session_id: i32) -> Result { + if let Some(message) = self + .pending_messages + .get_mut(&session_id) + .and_then(VecDeque::pop_front) + { + return Ok(message); + } + + loop { + let Some(line) = self.stdout_lines.next_line().await? else { + match self.wait_for_exit().await { + Ok(status) => { + self.join_stderr_task().await; + return Err(std::io::Error::other(format!( + "{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})" + ))); + } + Err(err) => return Err(std::io::Error::other(err)), + } + }; + if line.trim().is_empty() { + continue; + } + let message: NodeToHostMessage = + serde_json::from_str(&line).map_err(std::io::Error::other)?; + let message_session_id = message_session_id(&message); + if message_session_id == session_id { + return Ok(message); + } + self.pending_messages + .entry(message_session_id) + .or_default() + .push_back(message); + } + } + fn has_exited(&mut self) -> Result { self.child .try_wait() @@ -69,12 +116,62 @@ impl CodeModeProcess { .map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}")) } - async fn stderr(&mut self) -> Result { - self.stderr_task - .take() - .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} stderr collector missing"))? - .await - .map_err(|err| format!("failed to collect {PUBLIC_TOOL_NAME} stderr: {err}")) + async fn join_stderr_task(&mut self) { + let Some(stderr_task) = self.stderr_task.take() else { + return; + }; + if let Err(err) = stderr_task.await { + warn!("failed to join {PUBLIC_TOOL_NAME} stderr task: {err}"); + } + } +} + +pub(crate) struct CodeModeService { + js_repl_node_path: Option, + stored_values: Mutex>, + process: Arc>>, + next_session_id: Mutex, +} + +impl CodeModeService { + pub(crate) fn new(js_repl_node_path: Option) -> Self { + Self { + js_repl_node_path, + stored_values: Mutex::new(HashMap::new()), + process: Arc::new(Mutex::new(None)), + next_session_id: Mutex::new(1), + } + } + + pub(crate) async fn stored_values(&self) -> HashMap { + self.stored_values.lock().await.clone() + } + + pub(crate) async fn replace_stored_values(&self, values: HashMap) { + *self.stored_values.lock().await = values; + } + + async fn ensure_started( + &self, + ) -> Result>, String> { + let mut process_slot = self.process.lock().await; + let needs_spawn = match process_slot.as_mut() { + Some(process) => !matches!(process.has_exited(), Ok(false)), + None => true, + }; + if needs_spawn { + let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()).await?; + *process_slot = Some(spawn_code_mode_process(&node_path).await?); + } + drop(process_slot); + Ok(self.process.clone().lock_owned().await) + } + + pub(crate) async fn allocate_session_id(&self) -> i32 { + let mut next_session_id = self.next_session_id.lock().await; + let session_id = *next_session_id; + *next_session_id = next_session_id.saturating_add(1); + session_id } } @@ -132,14 +229,10 @@ enum NodeToHostMessage { Yielded { session_id: i32, content_items: Vec, - #[serde(default)] - max_output_tokens_per_exec_call: Option, }, Terminated { session_id: i32, content_items: Vec, - #[serde(default)] - max_output_tokens_per_exec_call: Option, }, Result { session_id: i32, @@ -152,27 +245,9 @@ enum NodeToHostMessage { }, } -enum CodeModeSessionAction { - Start { - enabled_tools: Vec, - stored_values: HashMap, - source: String, - }, - Poll { - yield_time_ms: u64, - max_output_tokens: Option, - }, - Terminate { - max_output_tokens: Option, - }, -} - enum CodeModeSessionProgress { Finished(FunctionToolOutput), - Yielded { - output: FunctionToolOutput, - session_id: i32, - }, + Yielded { output: FunctionToolOutput }, } enum CodeModeExecutionStatus { @@ -228,53 +303,38 @@ pub(crate) async fn execute( tracker, }; let enabled_tools = build_enabled_tools(&exec).await; - let stored_values = exec.session.services.code_mode_store.stored_values().await; + let service = &exec.session.services.code_mode_service; + let stored_values = service.stored_values().await; let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; - let session_id = exec - .session - .services - .code_mode_store - .allocate_session_id() - .await; + let session_id = service.allocate_session_id().await; + let process_slot = service + .ensure_started() + .await + .map_err(FunctionCallError::RespondToModel)?; let result = { - let process_slot = exec.session.services.code_mode_store.process(); - let mut process_slot = process_slot.lock().await; - let needs_spawn = match process_slot.as_mut() { - Some(process) => !matches!(process.has_exited(), Ok(false)), - None => true, + let mut process_slot = process_slot; + let Some(process) = process_slot.as_mut() else { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); }; - if needs_spawn { - *process_slot = Some( - spawn_code_mode_process(&exec) - .await - .map_err(FunctionCallError::RespondToModel)?, - ); - } - let process = process_slot.as_mut().ok_or_else(|| { - FunctionCallError::RespondToModel(format!("{PUBLIC_TOOL_NAME} runner failed to start")) - })?; drive_code_mode_session( &exec, process, - session_id, - CodeModeSessionAction::Start { + HostToNodeMessage::Start { + session_id, enabled_tools, stored_values, source, }, + None, + false, ) .await }; - if let Ok(CodeModeSessionProgress::Yielded { session_id, .. }) = &result { - exec.session - .services - .code_mode_store - .store_yielded_session(*session_id) - .await; - } match result { Ok(CodeModeSessionProgress::Finished(output)) - | Ok(CodeModeSessionProgress::Yielded { output, .. }) => Ok(output), + | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), Err(error) => Err(FunctionCallError::RespondToModel(error)), } } @@ -293,71 +353,53 @@ pub(crate) async fn wait( turn, tracker, }; - let yielded_session_id = exec + let process_slot = exec .session .services - .code_mode_store - .take_yielded_session(session_id) + .code_mode_service + .ensure_started() .await - .ok_or_else(|| { - FunctionCallError::RespondToModel(format!( - "{WAIT_TOOL_NAME} session_id {session_id} is not waiting on {WAIT_TOOL_NAME}" - )) - })?; - + .map_err(FunctionCallError::RespondToModel)?; let result = { - let process_slot = exec.session.services.code_mode_store.process(); - let mut process_slot = process_slot.lock().await; - let process = process_slot.as_mut().ok_or_else(|| { - FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" - )) - })?; + let mut process_slot = process_slot; + let Some(process) = process_slot.as_mut() else { + return Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} runner failed to start" + ))); + }; if !matches!(process.has_exited(), Ok(false)) { return Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} runner is not available for {WAIT_TOOL_NAME}" + "{PUBLIC_TOOL_NAME} runner failed to start" ))); } drive_code_mode_session( &exec, process, - yielded_session_id, if terminate { - CodeModeSessionAction::Terminate { max_output_tokens } + HostToNodeMessage::Terminate { session_id } } else { - CodeModeSessionAction::Poll { + HostToNodeMessage::Poll { + session_id, yield_time_ms, - max_output_tokens, } }, + Some(max_output_tokens), + terminate, ) .await }; - if let Ok(CodeModeSessionProgress::Yielded { session_id, .. }) = &result { - exec.session - .services - .code_mode_store - .store_yielded_session(*session_id) - .await; - } - match result { Ok(CodeModeSessionProgress::Finished(output)) - | Ok(CodeModeSessionProgress::Yielded { output, .. }) => Ok(output), + | Ok(CodeModeSessionProgress::Yielded { output }) => Ok(output), Err(error) => Err(FunctionCallError::RespondToModel(error)), } } -async fn spawn_code_mode_process(exec: &ExecContext) -> Result { - let node_path = resolve_compatible_node(exec.turn.config.js_repl_node_path.as_deref()).await?; - let env = create_env(&exec.turn.shell_environment_policy, None); +async fn spawn_code_mode_process(node_path: &std::path::Path) -> Result { let mut cmd = tokio::process::Command::new(&node_path); cmd.arg("--experimental-vm-modules"); cmd.arg("--eval"); cmd.arg(CODE_MODE_RUNNER_SOURCE); - cmd.current_dir(&exec.turn.cwd); - cmd.env_clear(); - cmd.envs(env); cmd.stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) @@ -382,8 +424,17 @@ async fn spawn_code_mode_process(exec: &ExecContext) -> Result { + let stderr = String::from_utf8_lossy(&buf).trim().to_string(); + if !stderr.is_empty() { + warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); + } + } + Err(err) => { + warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); + } + } }); Ok(CodeModeProcess { @@ -398,144 +449,27 @@ async fn spawn_code_mode_process(exec: &ExecContext) -> Result>, + is_terminate: bool, ) -> Result { let started_at = std::time::Instant::now(); - let is_terminate = matches!(action, CodeModeSessionAction::Terminate { .. }); - let (message, poll_max_output_tokens) = match action { - CodeModeSessionAction::Start { - enabled_tools, - stored_values, - source, - } => ( - HostToNodeMessage::Start { - session_id, - enabled_tools, - stored_values, - source, - }, - None, - ), - CodeModeSessionAction::Poll { - yield_time_ms, - max_output_tokens, - } => ( - HostToNodeMessage::Poll { - session_id, - yield_time_ms, - }, - Some(max_output_tokens), - ), - CodeModeSessionAction::Terminate { max_output_tokens } => ( - HostToNodeMessage::Terminate { session_id }, - Some(max_output_tokens), - ), + let session_id = match &message { + HostToNodeMessage::Start { session_id, .. } + | HostToNodeMessage::Poll { session_id, .. } + | HostToNodeMessage::Terminate { session_id } + | HostToNodeMessage::Response { session_id, .. } => *session_id, }; - if let Some(progress) = process_pending_messages( - exec, - process, - session_id, - poll_max_output_tokens, - started_at, - is_terminate, - ) - .await? - { - return Ok(progress); - } - write_message(&mut process.stdin, &message).await?; - - if let Some(progress) = process_pending_messages( - exec, - process, - session_id, - poll_max_output_tokens, - started_at, - is_terminate, - ) - .await? - { - return Ok(progress); - } - - while let Some(line) = process - .stdout_lines - .next_line() + process + .write(&message) .await - .map_err(|err| format!("failed to read {PUBLIC_TOOL_NAME} runner stdout: {err}"))? - { - if line.trim().is_empty() { - continue; - } - let message: NodeToHostMessage = serde_json::from_str(&line).map_err(|err| { - format!("invalid {PUBLIC_TOOL_NAME} runner message: {err}; line={line}") - })?; - let message_session_id = message_session_id(&message); - if message_session_id != session_id { - if let NodeToHostMessage::ToolCall { - session_id: message_session_id, - id, - name, - input, - } = message - { - let response = HostToNodeMessage::Response { - session_id: message_session_id, - id, - code_mode_result: call_nested_tool(exec.clone(), name, input).await, - }; - write_message(&mut process.stdin, &response).await?; - } else { - process - .pending_messages - .entry(message_session_id) - .or_default() - .push_back(message); - } - continue; - } - if let Some(progress) = handle_node_message( - exec, - process, - session_id, - message, - poll_max_output_tokens, - started_at, - is_terminate, - ) - .await? - { - return Ok(progress); - } - } + .map_err(|err| err.to_string())?; - let status = process.wait_for_exit().await?; - let stderr = process.stderr().await?; - let message = if stderr.is_empty() { - format!("{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})") - } else { - stderr - }; - Err(message) -} - -async fn process_pending_messages( - exec: &ExecContext, - process: &mut CodeModeProcess, - session_id: i32, - poll_max_output_tokens: Option>, - started_at: std::time::Instant, - is_terminate: bool, -) -> Result, String> { loop { - let Some(message) = process - .pending_messages - .get_mut(&session_id) - .and_then(VecDeque::pop_front) - else { - return Ok(None); - }; + let message = process + .read(session_id) + .await + .map_err(|err| err.to_string())?; if let Some(progress) = handle_node_message( exec, process, @@ -547,7 +481,7 @@ async fn process_pending_messages( ) .await? { - return Ok(Some(progress)); + return Ok(progress); } } } @@ -576,22 +510,18 @@ async fn handle_node_message( id, code_mode_result: call_nested_tool(exec.clone(), name, input).await, }; - write_message(&mut process.stdin, &response).await?; + process + .write(&response) + .await + .map_err(|err| err.to_string())?; Ok(None) } - NodeToHostMessage::Yielded { - content_items, - max_output_tokens_per_exec_call, - .. - } => { + NodeToHostMessage::Yielded { content_items, .. } => { if is_terminate { return Ok(None); } let mut delta_items = output_content_items_from_json_values(content_items)?; - delta_items = truncate_code_mode_result( - delta_items, - poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), - ); + delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); prepend_script_status( &mut delta_items, CodeModeExecutionStatus::Running(session_id), @@ -599,19 +529,11 @@ async fn handle_node_message( ); Ok(Some(CodeModeSessionProgress::Yielded { output: FunctionToolOutput::from_content(delta_items, Some(true)), - session_id, })) } - NodeToHostMessage::Terminated { - content_items, - max_output_tokens_per_exec_call, - .. - } => { + NodeToHostMessage::Terminated { content_items, .. } => { let mut delta_items = output_content_items_from_json_values(content_items)?; - delta_items = truncate_code_mode_result( - delta_items, - poll_max_output_tokens.unwrap_or(max_output_tokens_per_exec_call), - ); + delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); prepend_script_status( &mut delta_items, CodeModeExecutionStatus::Terminated, @@ -630,7 +552,7 @@ async fn handle_node_message( } => { exec.session .services - .code_mode_store + .code_mode_service .replace_stored_values(stored_values) .await; let mut delta_items = output_content_items_from_json_values(content_items)?; @@ -670,26 +592,6 @@ fn message_session_id(message: &NodeToHostMessage) -> i32 { } } -async fn write_message( - stdin: &mut tokio::process::ChildStdin, - message: &HostToNodeMessage, -) -> Result<(), String> { - let line = serde_json::to_string(message) - .map_err(|err| format!("failed to serialize {PUBLIC_TOOL_NAME} message: {err}"))?; - stdin - .write_all(line.as_bytes()) - .await - .map_err(|err| format!("failed to write {PUBLIC_TOOL_NAME} message: {err}"))?; - stdin - .write_all(b"\n") - .await - .map_err(|err| format!("failed to write {PUBLIC_TOOL_NAME} message newline: {err}"))?; - stdin - .flush() - .await - .map_err(|err| format!("failed to flush {PUBLIC_TOOL_NAME} message: {err}")) -} - fn prepend_script_status( content_items: &mut Vec, status: CodeModeExecutionStatus, diff --git a/codex-rs/core/src/tools/code_mode_runner.cjs b/codex-rs/core/src/tools/code_mode_runner.cjs index 5bad1c61992..d64e369f320 100644 --- a/codex-rs/core/src/tools/code_mode_runner.cjs +++ b/codex-rs/core/src/tools/code_mode_runner.cjs @@ -467,6 +467,15 @@ function createProtocol() { const session = sessions.get(message.session_id); if (session) { schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0)); + } else { + void protocol.send({ + type: 'result', + session_id: message.session_id, + content_items: [], + stored_values: {}, + error_text: `exec session ${message.session_id} not found`, + max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + }); } return; } @@ -475,6 +484,15 @@ function createProtocol() { const session = sessions.get(message.session_id); if (session) { void terminateSession(protocol, sessions, session); + } else { + void protocol.send({ + type: 'result', + session_id: message.session_id, + content_items: [], + stored_values: {}, + error_text: `exec session ${message.session_id} not found`, + max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + }); } return; } @@ -662,7 +680,6 @@ async function sendYielded(protocol, session) { type: 'yielded', session_id: session.id, content_items: contentItems, - max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, }); } @@ -726,7 +743,6 @@ async function terminateSession(protocol, sessions, session) { type: 'terminated', session_id: session.id, content_items: contentItems, - max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, }); } diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index 3112ae1dce5..7b775a80e1e 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -789,6 +789,66 @@ output_text("after terminate"); Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn code_mode_exec_wait_returns_error_for_unknown_session() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + let mut builder = test_codex().with_config(move |config| { + let _ = config.features.enable(Feature::CodeMode); + }); + let test = builder.build(&server).await?; + + responses::mount_sse_once( + &server, + sse(vec![ + ev_response_created("resp-1"), + responses::ev_function_call( + "call-1", + "exec_wait", + &serde_json::to_string(&serde_json::json!({ + "session_id": 999_999, + "yield_time_ms": 1_000, + }))?, + ), + ev_completed("resp-1"), + ]), + ) + .await; + let completion = responses::mount_sse_once( + &server, + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ) + .await; + + test.submit_turn("wait on an unknown exec session").await?; + + let request = completion.single_request(); + let (_, success) = request + .function_call_output_content_and_success("call-1") + .expect("function tool output should be present"); + assert_ne!(success, Some(true)); + + let items = function_tool_output_items(&request, "call-1"); + assert_eq!(items.len(), 2); + assert_regex_match( + concat!( + r"(?s)\A", + r"Script failed\nWall time \d+\.\d seconds\nOutput:\n\z" + ), + text_item(&items, 0), + ); + assert_eq!( + text_item(&items, 1), + "Script error:\nexec session 999999 not found" + ); + + Ok(()) +} + #[cfg_attr(windows, ignore = "no exec_command on Windows")] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn code_mode_exec_wait_terminate_returns_completed_session_if_it_finished_in_background() @@ -962,15 +1022,14 @@ output_text("session b done"); let fourth_request = fourth_completion.single_request(); let fourth_items = function_tool_output_items(&fourth_request, "call-4"); - assert_eq!(fourth_items.len(), 2); + assert_eq!(fourth_items.len(), 1); assert_regex_match( concat!( r"(?s)\A", - r"Script completed\nWall time \d+\.\d seconds\nOutput:\n\z" + r"Script terminated\nWall time \d+\.\d seconds\nOutput:\n\z" ), text_item(&fourth_items, 0), ); - assert_eq!(text_item(&fourth_items, 1), "session a done"); Ok(()) } From 5c06a41d7f7d458060631013d819d194c6618c7f Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 20:56:21 -0700 Subject: [PATCH 5/7] Simplify code mode file wait helper --- codex-rs/core/tests/suite/code_mode.rs | 40 +++----------------------- 1 file changed, 4 insertions(+), 36 deletions(-) diff --git a/codex-rs/core/tests/suite/code_mode.rs b/codex-rs/core/tests/suite/code_mode.rs index 7b775a80e1e..23fcd9c0893 100644 --- a/codex-rs/core/tests/suite/code_mode.rs +++ b/codex-rs/core/tests/suite/code_mode.rs @@ -61,7 +61,10 @@ fn extract_running_session_id(text: &str) -> i32 { fn wait_for_file_source(path: &Path) -> Result { let quoted_path = shlex::try_join([path.to_string_lossy().as_ref()])?; let command = format!("if [ -f {quoted_path} ]; then printf ready; fi"); - Ok(format!("await waitForFile({command:?});")) + Ok(format!( + r#"while ((await exec_command({{ cmd: {command:?} }})).output !== "ready") {{ +}}"# + )) } fn custom_tool_output_body_and_success( @@ -334,11 +337,6 @@ async fn code_mode_can_yield_and_resume_with_exec_wait() -> Result<()> { import {{ output_text, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("phase 1"); set_yield_time(10); {phase_2_wait} @@ -488,11 +486,6 @@ async fn code_mode_can_run_multiple_yielded_sessions() -> Result<()> { import {{ output_text, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("session a start"); set_yield_time(10); {session_a_wait} @@ -504,11 +497,6 @@ output_text("session a done"); import {{ output_text, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("session b start"); set_yield_time(10); {session_b_wait} @@ -670,11 +658,6 @@ async fn code_mode_exec_wait_can_terminate_and_continue() -> Result<()> { import {{ output_text, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("phase 1"); set_yield_time(10); {termination_wait} @@ -870,11 +853,6 @@ async fn code_mode_exec_wait_terminate_returns_completed_session_if_it_finished_ import {{ output_text, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("session a start"); set_yield_time(10); {session_a_wait} @@ -886,11 +864,6 @@ output_text("session a done"); import {{ output_text, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("session b start"); set_yield_time(10); {session_b_wait} @@ -1052,11 +1025,6 @@ async fn code_mode_exec_wait_uses_its_own_max_tokens_budget() -> Result<()> { import {{ output_text, set_max_output_tokens_per_exec_call, set_yield_time }} from "@openai/code_mode"; import {{ exec_command }} from "tools.js"; -const waitForFile = async (cmd) => {{ - while ((await exec_command({{ cmd }})).output !== "ready") {{ - }} -}}; - output_text("phase 1"); set_max_output_tokens_per_exec_call(100); set_yield_time(10); From 2ddde61431437a29f658372f177c865a3914d75a Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 21:02:35 -0700 Subject: [PATCH 6/7] codex: fix CI failure on PR #14295 --- codex-rs/core/src/tools/code_mode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index 9971aa31720..d791c53076a 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -396,7 +396,7 @@ pub(crate) async fn wait( } async fn spawn_code_mode_process(node_path: &std::path::Path) -> Result { - let mut cmd = tokio::process::Command::new(&node_path); + let mut cmd = tokio::process::Command::new(node_path); cmd.arg("--experimental-vm-modules"); cmd.arg("--eval"); cmd.arg(CODE_MODE_RUNNER_SOURCE); From a0072cf521a452088e50c91e8fb4f9d310bc048a Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Wed, 11 Mar 2026 23:14:52 -0700 Subject: [PATCH 7/7] Refactor code mode worker dispatch --- codex-rs/core/src/codex.rs | 5 + codex-rs/core/src/tools/code_mode.rs | 537 +++++++++++-------- codex-rs/core/src/tools/code_mode_runner.cjs | 63 ++- 3 files changed, 366 insertions(+), 239 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 9976e6165a1..cc01bf6b690 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -5464,6 +5464,11 @@ pub(crate) async fn run_turn( // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains // many turns, from the perspective of the user, it is a single turn. let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let _code_mode_worker = sess + .services + .code_mode_service + .start_turn_worker(&sess, &turn_context, &turn_diff_tracker) + .await; let mut server_model_warning_emitted_for_turn = false; // `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse diff --git a/codex-rs/core/src/tools/code_mode.rs b/codex-rs/core/src/tools/code_mode.rs index d791c53076a..a39543f258f 100644 --- a/codex-rs/core/src/tools/code_mode.rs +++ b/codex-rs/core/src/tools/code_mode.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::VecDeque; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -32,6 +31,8 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use tracing::warn; @@ -50,80 +51,110 @@ struct ExecContext { pub(crate) struct CodeModeProcess { child: tokio::process::Child, - stdin: tokio::process::ChildStdin, - stdout_lines: tokio::io::Lines>, - stderr_task: Option>, - pending_messages: HashMap>, + stdin: Arc>, + stdout_task: JoinHandle<()>, + response_waiters: Arc>>>, + tool_call_rx: Arc>>, } -impl CodeModeProcess { - async fn write(&mut self, message: &HostToNodeMessage) -> Result<(), std::io::Error> { - let line = serde_json::to_string(message).map_err(std::io::Error::other)?; - self.stdin.write_all(line.as_bytes()).await?; - self.stdin.write_all(b"\n").await?; - self.stdin.flush().await?; - Ok(()) - } +pub(crate) struct CodeModeWorker { + shutdown_tx: Option>, + task: JoinHandle<()>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +struct CodeModeToolCall { + request_id: String, + id: String, + name: String, + #[serde(default)] + input: Option, +} - async fn read(&mut self, session_id: i32) -> Result { - if let Some(message) = self - .pending_messages - .get_mut(&session_id) - .and_then(VecDeque::pop_front) - { - return Ok(message); +impl Drop for CodeModeWorker { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); } + } +} - loop { - let Some(line) = self.stdout_lines.next_line().await? else { - match self.wait_for_exit().await { - Ok(status) => { - self.join_stderr_task().await; - return Err(std::io::Error::other(format!( - "{PUBLIC_TOOL_NAME} runner exited without returning a result (status {status})" - ))); - } - Err(err) => return Err(std::io::Error::other(err)), +impl CodeModeProcess { + fn worker(&self, exec: ExecContext) -> CodeModeWorker { + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let stdin = Arc::clone(&self.stdin); + let tool_call_rx = Arc::clone(&self.tool_call_rx); + let task = tokio::spawn(async move { + loop { + let tool_call = tokio::select! { + _ = &mut shutdown_rx => break, + tool_call = async { + let mut tool_call_rx = tool_call_rx.lock().await; + tool_call_rx.recv().await + } => tool_call, + }; + let Some(tool_call) = tool_call else { + break; + }; + let response = HostToNodeMessage::Response { + request_id: tool_call.request_id, + id: tool_call.id, + code_mode_result: call_nested_tool( + exec.clone(), + tool_call.name, + tool_call.input, + ) + .await, + }; + if let Err(err) = write_message(&stdin, &response).await { + warn!("failed to write {PUBLIC_TOOL_NAME} tool response: {err}"); + break; } - }; - if line.trim().is_empty() { - continue; } - let message: NodeToHostMessage = - serde_json::from_str(&line).map_err(std::io::Error::other)?; - let message_session_id = message_session_id(&message); - if message_session_id == session_id { - return Ok(message); - } - self.pending_messages - .entry(message_session_id) - .or_default() - .push_back(message); + }); + + CodeModeWorker { + shutdown_tx: Some(shutdown_tx), + task, } } - fn has_exited(&mut self) -> Result { - self.child - .try_wait() - .map(|status| status.is_some()) - .map_err(|err| format!("failed to inspect {PUBLIC_TOOL_NAME} runner: {err}")) - } + async fn send( + &mut self, + request_id: &str, + message: &HostToNodeMessage, + ) -> Result { + if self.stdout_task.is_finished() { + return Err(std::io::Error::other(format!( + "{PUBLIC_TOOL_NAME} runner is not available" + ))); + } - async fn wait_for_exit(&mut self) -> Result { - self.child - .wait() + let (tx, rx) = oneshot::channel(); + self.response_waiters + .lock() .await - .map_err(|err| format!("failed to wait for {PUBLIC_TOOL_NAME} runner: {err}")) - } + .insert(request_id.to_string(), tx); + if let Err(err) = write_message(&self.stdin, message).await { + self.response_waiters.lock().await.remove(request_id); + return Err(err); + } - async fn join_stderr_task(&mut self) { - let Some(stderr_task) = self.stderr_task.take() else { - return; - }; - if let Err(err) = stderr_task.await { - warn!("failed to join {PUBLIC_TOOL_NAME} stderr task: {err}"); + match rx.await { + Ok(message) => Ok(message), + Err(_) => Err(std::io::Error::other(format!( + "{PUBLIC_TOOL_NAME} runner is not available" + ))), } } + + fn has_exited(&mut self) -> Result { + self.child + .try_wait() + .map(|status| status.is_some()) + .map_err(std::io::Error::other) + } } pub(crate) struct CodeModeService { @@ -153,26 +184,62 @@ impl CodeModeService { async fn ensure_started( &self, - ) -> Result>, String> { + ) -> Result>, std::io::Error> { let mut process_slot = self.process.lock().await; let needs_spawn = match process_slot.as_mut() { Some(process) => !matches!(process.has_exited(), Ok(false)), None => true, }; if needs_spawn { - let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()).await?; + let node_path = resolve_compatible_node(self.js_repl_node_path.as_deref()) + .await + .map_err(std::io::Error::other)?; *process_slot = Some(spawn_code_mode_process(&node_path).await?); } drop(process_slot); Ok(self.process.clone().lock_owned().await) } + pub(crate) async fn start_turn_worker( + &self, + session: &Arc, + turn: &Arc, + tracker: &SharedTurnDiffTracker, + ) -> Option { + if !turn.features.enabled(Feature::CodeMode) { + return None; + } + let exec = ExecContext { + session: Arc::clone(session), + turn: Arc::clone(turn), + tracker: Arc::clone(tracker), + }; + let mut process_slot = match self.ensure_started().await { + Ok(process_slot) => process_slot, + Err(err) => { + warn!("failed to start {PUBLIC_TOOL_NAME} worker for turn: {err}"); + return None; + } + }; + let Some(process) = process_slot.as_mut() else { + warn!( + "failed to start {PUBLIC_TOOL_NAME} worker for turn: {PUBLIC_TOOL_NAME} runner failed to start" + ); + return None; + }; + Some(process.worker(exec)) + } + pub(crate) async fn allocate_session_id(&self) -> i32 { let mut next_session_id = self.next_session_id.lock().await; let session_id = *next_session_id; *next_session_id = next_session_id.saturating_add(1); session_id } + + pub(crate) async fn allocate_request_id(&self) -> String { + uuid::Uuid::new_v4().to_string() + } } #[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)] @@ -197,20 +264,23 @@ struct EnabledTool { #[serde(tag = "type", rename_all = "snake_case")] enum HostToNodeMessage { Start { + request_id: String, session_id: i32, enabled_tools: Vec, stored_values: HashMap, source: String, }, Poll { + request_id: String, session_id: i32, yield_time_ms: u64, }, Terminate { + request_id: String, session_id: i32, }, Response { - session_id: i32, + request_id: String, id: String, code_mode_result: JsonValue, }, @@ -220,22 +290,19 @@ enum HostToNodeMessage { #[serde(tag = "type", rename_all = "snake_case")] enum NodeToHostMessage { ToolCall { - session_id: i32, - id: String, - name: String, - #[serde(default)] - input: Option, + #[serde(flatten)] + tool_call: CodeModeToolCall, }, Yielded { - session_id: i32, + request_id: String, content_items: Vec, }, Terminated { - session_id: i32, + request_id: String, content_items: Vec, }, Result { - session_id: i32, + request_id: String, content_items: Vec, stored_values: HashMap, #[serde(default)] @@ -307,10 +374,19 @@ pub(crate) async fn execute( let stored_values = service.stored_values().await; let source = build_source(&code, &enabled_tools).map_err(FunctionCallError::RespondToModel)?; let session_id = service.allocate_session_id().await; + let request_id = service.allocate_request_id().await; let process_slot = service .ensure_started() .await - .map_err(FunctionCallError::RespondToModel)?; + .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; + let started_at = std::time::Instant::now(); + let message = HostToNodeMessage::Start { + request_id: request_id.clone(), + session_id, + enabled_tools, + stored_values, + source, + }; let result = { let mut process_slot = process_slot; let Some(process) = process_slot.as_mut() else { @@ -318,19 +394,15 @@ pub(crate) async fn execute( "{PUBLIC_TOOL_NAME} runner failed to start" ))); }; - drive_code_mode_session( - &exec, - process, - HostToNodeMessage::Start { - session_id, - enabled_tools, - stored_values, - source, - }, - None, - false, - ) - .await + let message = process + .send(&request_id, &message) + .await + .map_err(|err| err.to_string()); + let message = match message { + Ok(message) => message, + Err(error) => return Err(FunctionCallError::RespondToModel(error)), + }; + handle_node_message(&exec, session_id, message, None, started_at).await }; match result { Ok(CodeModeSessionProgress::Finished(output)) @@ -353,13 +425,32 @@ pub(crate) async fn wait( turn, tracker, }; + let request_id = exec + .session + .services + .code_mode_service + .allocate_request_id() + .await; + let started_at = std::time::Instant::now(); + let message = if terminate { + HostToNodeMessage::Terminate { + request_id: request_id.clone(), + session_id, + } + } else { + HostToNodeMessage::Poll { + request_id: request_id.clone(), + session_id, + yield_time_ms, + } + }; let process_slot = exec .session .services .code_mode_service .ensure_started() .await - .map_err(FunctionCallError::RespondToModel)?; + .map_err(|err| FunctionCallError::RespondToModel(err.to_string()))?; let result = { let mut process_slot = process_slot; let Some(process) = process_slot.as_mut() else { @@ -372,19 +463,20 @@ pub(crate) async fn wait( "{PUBLIC_TOOL_NAME} runner failed to start" ))); } - drive_code_mode_session( + let message = process + .send(&request_id, &message) + .await + .map_err(|err| err.to_string()); + let message = match message { + Ok(message) => message, + Err(error) => return Err(FunctionCallError::RespondToModel(error)), + }; + handle_node_message( &exec, - process, - if terminate { - HostToNodeMessage::Terminate { session_id } - } else { - HostToNodeMessage::Poll { - session_id, - yield_time_ms, - } - }, + session_id, + message, Some(max_output_tokens), - terminate, + started_at, ) .await }; @@ -395,131 +487,18 @@ pub(crate) async fn wait( } } -async fn spawn_code_mode_process(node_path: &std::path::Path) -> Result { - let mut cmd = tokio::process::Command::new(node_path); - cmd.arg("--experimental-vm-modules"); - cmd.arg("--eval"); - cmd.arg(CODE_MODE_RUNNER_SOURCE); - cmd.stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - - let mut child = cmd - .spawn() - .map_err(|err| format!("failed to start {PUBLIC_TOOL_NAME} Node runtime: {err}"))?; - let stdout = child - .stdout - .take() - .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdout"))?; - let stderr = child - .stderr - .take() - .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stderr"))?; - let stdin = child - .stdin - .take() - .ok_or_else(|| format!("{PUBLIC_TOOL_NAME} runner missing stdin"))?; - - let stderr_task = tokio::spawn(async move { - let mut reader = BufReader::new(stderr); - let mut buf = Vec::new(); - match reader.read_to_end(&mut buf).await { - Ok(_) => { - let stderr = String::from_utf8_lossy(&buf).trim().to_string(); - if !stderr.is_empty() { - warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); - } - } - Err(err) => { - warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); - } - } - }); - - Ok(CodeModeProcess { - child, - stdin, - stdout_lines: BufReader::new(stdout).lines(), - stderr_task: Some(stderr_task), - pending_messages: HashMap::new(), - }) -} - -async fn drive_code_mode_session( - exec: &ExecContext, - process: &mut CodeModeProcess, - message: HostToNodeMessage, - poll_max_output_tokens: Option>, - is_terminate: bool, -) -> Result { - let started_at = std::time::Instant::now(); - let session_id = match &message { - HostToNodeMessage::Start { session_id, .. } - | HostToNodeMessage::Poll { session_id, .. } - | HostToNodeMessage::Terminate { session_id } - | HostToNodeMessage::Response { session_id, .. } => *session_id, - }; - process - .write(&message) - .await - .map_err(|err| err.to_string())?; - - loop { - let message = process - .read(session_id) - .await - .map_err(|err| err.to_string())?; - if let Some(progress) = handle_node_message( - exec, - process, - session_id, - message, - poll_max_output_tokens, - started_at, - is_terminate, - ) - .await? - { - return Ok(progress); - } - } -} - async fn handle_node_message( exec: &ExecContext, - process: &mut CodeModeProcess, session_id: i32, message: NodeToHostMessage, poll_max_output_tokens: Option>, started_at: std::time::Instant, - is_terminate: bool, -) -> Result, String> { +) -> Result { match message { - NodeToHostMessage::ToolCall { - session_id: message_session_id, - id, - name, - input, - } => { - if is_terminate { - return Ok(None); - } - let response = HostToNodeMessage::Response { - session_id: message_session_id, - id, - code_mode_result: call_nested_tool(exec.clone(), name, input).await, - }; - process - .write(&response) - .await - .map_err(|err| err.to_string())?; - Ok(None) - } + NodeToHostMessage::ToolCall { .. } => Err(format!( + "{PUBLIC_TOOL_NAME} received an unexpected tool call response" + )), NodeToHostMessage::Yielded { content_items, .. } => { - if is_terminate { - return Ok(None); - } let mut delta_items = output_content_items_from_json_values(content_items)?; delta_items = truncate_code_mode_result(delta_items, poll_max_output_tokens.flatten()); prepend_script_status( @@ -527,9 +506,9 @@ async fn handle_node_message( CodeModeExecutionStatus::Running(session_id), started_at.elapsed(), ); - Ok(Some(CodeModeSessionProgress::Yielded { + Ok(CodeModeSessionProgress::Yielded { output: FunctionToolOutput::from_content(delta_items, Some(true)), - })) + }) } NodeToHostMessage::Terminated { content_items, .. } => { let mut delta_items = output_content_items_from_json_values(content_items)?; @@ -539,9 +518,9 @@ async fn handle_node_message( CodeModeExecutionStatus::Terminated, started_at.elapsed(), ); - Ok(Some(CodeModeSessionProgress::Finished( + Ok(CodeModeSessionProgress::Finished( FunctionToolOutput::from_content(delta_items, Some(true)), - ))) + )) } NodeToHostMessage::Result { content_items, @@ -576,19 +555,127 @@ async fn handle_node_message( }, started_at.elapsed(), ); - Ok(Some(CodeModeSessionProgress::Finished( + Ok(CodeModeSessionProgress::Finished( FunctionToolOutput::from_content(delta_items, Some(success)), - ))) + )) } } } -fn message_session_id(message: &NodeToHostMessage) -> i32 { +async fn spawn_code_mode_process( + node_path: &std::path::Path, +) -> Result { + let mut cmd = tokio::process::Command::new(node_path); + cmd.arg("--experimental-vm-modules"); + cmd.arg("--eval"); + cmd.arg(CODE_MODE_RUNNER_SOURCE); + cmd.stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); + + let mut child = cmd.spawn().map_err(std::io::Error::other)?; + let stdout = child.stdout.take().ok_or_else(|| { + std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdout")) + })?; + let stderr = child.stderr.take().ok_or_else(|| { + std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stderr")) + })?; + let stdin = child + .stdin + .take() + .ok_or_else(|| std::io::Error::other(format!("{PUBLIC_TOOL_NAME} runner missing stdin")))?; + let stdin = Arc::new(Mutex::new(stdin)); + let response_waiters = Arc::new(Mutex::new(HashMap::< + String, + oneshot::Sender, + >::new())); + let (tool_call_tx, tool_call_rx) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + let mut reader = BufReader::new(stderr); + let mut buf = Vec::new(); + match reader.read_to_end(&mut buf).await { + Ok(_) => { + let stderr = String::from_utf8_lossy(&buf).trim().to_string(); + if !stderr.is_empty() { + warn!("{PUBLIC_TOOL_NAME} runner stderr: {stderr}"); + } + } + Err(err) => { + warn!("failed to read {PUBLIC_TOOL_NAME} stderr: {err}"); + } + } + }); + let stdout_task = tokio::spawn({ + let response_waiters = Arc::clone(&response_waiters); + let tool_call_tx = tool_call_tx.clone(); + async move { + let mut stdout_lines = BufReader::new(stdout).lines(); + loop { + let line = match stdout_lines.next_line().await { + Ok(line) => line, + Err(err) => { + warn!("failed to read {PUBLIC_TOOL_NAME} stdout: {err}"); + break; + } + }; + let Some(line) = line else { + break; + }; + if line.trim().is_empty() { + continue; + } + let message: NodeToHostMessage = match serde_json::from_str(&line) { + Ok(message) => message, + Err(err) => { + warn!("failed to parse {PUBLIC_TOOL_NAME} stdout message: {err}"); + break; + } + }; + match message { + NodeToHostMessage::ToolCall { tool_call } => { + let _ = tool_call_tx.send(tool_call); + } + message => { + let request_id = message_request_id(&message).to_string(); + if let Some(waiter) = response_waiters.lock().await.remove(&request_id) { + let _ = waiter.send(message); + } + } + } + } + response_waiters.lock().await.clear(); + } + }); + + Ok(CodeModeProcess { + child, + stdin, + stdout_task, + response_waiters, + tool_call_rx: Arc::new(Mutex::new(tool_call_rx)), + }) +} + +async fn write_message( + stdin: &Arc>, + message: &HostToNodeMessage, +) -> Result<(), std::io::Error> { + let line = serde_json::to_string(message).map_err(std::io::Error::other)?; + let mut stdin = stdin.lock().await; + stdin.write_all(line.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + Ok(()) +} + +fn message_request_id(message: &NodeToHostMessage) -> &str { match message { - NodeToHostMessage::ToolCall { session_id, .. } - | NodeToHostMessage::Yielded { session_id, .. } - | NodeToHostMessage::Terminated { session_id, .. } - | NodeToHostMessage::Result { session_id, .. } => *session_id, + NodeToHostMessage::ToolCall { tool_call } => &tool_call.request_id, + NodeToHostMessage::Yielded { request_id, .. } + | NodeToHostMessage::Terminated { request_id, .. } + | NodeToHostMessage::Result { request_id, .. } => request_id, } } diff --git a/codex-rs/core/src/tools/code_mode_runner.cjs b/codex-rs/core/src/tools/code_mode_runner.cjs index d64e369f320..c7ed8fdebde 100644 --- a/codex-rs/core/src/tools/code_mode_runner.cjs +++ b/codex-rs/core/src/tools/code_mode_runner.cjs @@ -466,11 +466,18 @@ function createProtocol() { if (message.type === 'poll') { const session = sessions.get(message.session_id); if (session) { - schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0)); + session.request_id = String(message.request_id); + if (session.pending_result) { + void completeSession(protocol, sessions, session, session.pending_result); + } else if (session.pending_tool_call) { + void forwardToolCall(protocol, session, session.pending_tool_call); + } else { + schedulePollYield(protocol, session, normalizeYieldTime(message.yield_time_ms ?? 0)); + } } else { void protocol.send({ type: 'result', - session_id: message.session_id, + request_id: message.request_id, content_items: [], stored_values: {}, error_text: `exec session ${message.session_id} not found`, @@ -483,11 +490,12 @@ function createProtocol() { if (message.type === 'terminate') { const session = sessions.get(message.session_id); if (session) { + session.request_id = String(message.request_id); void terminateSession(protocol, sessions, session); } else { void protocol.send({ type: 'result', - session_id: message.session_id, + request_id: message.request_id, content_items: [], stored_values: {}, error_text: `exec session ${message.session_id} not found`, @@ -498,11 +506,11 @@ function createProtocol() { } if (message.type === 'response') { - const entry = pending.get(message.session_id + ':' + message.id); + const entry = pending.get(message.request_id + ':' + message.id); if (!entry) { return; } - pending.delete(message.session_id + ':' + message.id); + pending.delete(message.request_id + ':' + message.id); entry.resolve(message.code_mode_result ?? ''); return; } @@ -537,12 +545,12 @@ function createProtocol() { }); } - function request(sessionId, type, payload) { + function request(requestId, type, payload) { const id = 'msg-' + ++nextId; - const pendingKey = sessionId + ':' + id; + const pendingKey = requestId + ':' + id; return new Promise((resolve, reject) => { pending.set(pendingKey, { resolve, reject }); - void send({ type, session_id: sessionId, id, ...payload }).catch((error) => { + void send({ type, request_id: requestId, id, ...payload }).catch((error) => { pending.delete(pendingKey); reject(error); }); @@ -565,7 +573,10 @@ function startSession(protocol, sessions, start) { initial_yield_timer: null, initial_yield_triggered: false, max_output_tokens_per_exec_call: DEFAULT_MAX_OUTPUT_TOKENS_PER_EXEC_CALL, + pending_result: null, + pending_tool_call: null, poll_yield_timer: null, + request_id: String(start.request_id), worker: new Worker(sessionWorkerSource(), { eval: true, workerData: start, @@ -621,17 +632,28 @@ async function handleWorkerMessage(protocol, sessions, session, message) { } if (message.type === 'tool_call') { + if (session.request_id === null) { + session.pending_tool_call = message; + return; + } void forwardToolCall(protocol, session, message); return; } if (message.type === 'result') { - await completeSession(protocol, sessions, session, { + const result = { type: 'result', stored_values: cloneJsonValue(message.stored_values ?? {}), error_text: typeof message.error_text === 'string' ? message.error_text : undefined, - }); + }; + if (session.request_id === null) { + session.pending_result = result; + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + return; + } + await completeSession(protocol, sessions, session, result); return; } @@ -640,10 +662,11 @@ async function handleWorkerMessage(protocol, sessions, session, message) { async function forwardToolCall(protocol, session, message) { try { - const result = await protocol.request(session.id, 'tool_call', { + const result = await protocol.request(session.request_id, 'tool_call', { name: String(message.name), input: message.input, }); + session.pending_tool_call = null; if (session.completed) { return; } @@ -655,6 +678,7 @@ async function forwardToolCall(protocol, session, message) { }); } catch {} } catch (error) { + session.pending_tool_call = null; if (session.completed) { return; } @@ -673,14 +697,16 @@ async function sendYielded(protocol, session) { return; } const contentItems = takeContentItems(session); + const requestId = session.request_id; try { session.worker.postMessage({ type: 'clear_content' }); } catch {} await protocol.send({ type: 'yielded', - session_id: session.id, + request_id: requestId, content_items: contentItems, }); + session.request_id = null; } function scheduleInitialYield(protocol, session, yieldTime) { @@ -711,17 +737,26 @@ async function completeSession(protocol, sessions, session, message) { if (session.completed) { return; } + if (session.request_id === null) { + session.pending_result = message; + session.initial_yield_timer = clearTimer(session.initial_yield_timer); + session.poll_yield_timer = clearTimer(session.poll_yield_timer); + return; + } + const requestId = session.request_id; session.completed = true; session.initial_yield_timer = clearTimer(session.initial_yield_timer); session.poll_yield_timer = clearTimer(session.poll_yield_timer); sessions.delete(session.id); const contentItems = takeContentItems(session); + session.pending_result = null; + session.pending_tool_call = null; try { session.worker.postMessage({ type: 'clear_content' }); } catch {} await protocol.send({ ...message, - session_id: session.id, + request_id: requestId, content_items: contentItems, max_output_tokens_per_exec_call: session.max_output_tokens_per_exec_call, }); @@ -741,7 +776,7 @@ async function terminateSession(protocol, sessions, session) { } catch {} await protocol.send({ type: 'terminated', - session_id: session.id, + request_id: session.request_id, content_items: contentItems, }); }