Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions codex-rs/core/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2511,7 +2511,22 @@ impl Session {
turn_context: &TurnContext,
items: &[ResponseItem],
) {
self.record_into_history(items, turn_context).await;
self.record_conversation_items_with_history_policy(
turn_context,
items,
turn_context.truncation_policy,
)
.await;
}

/// Records `items` for the current turn, using a caller-supplied truncation
/// policy for the in-memory history step.
///
/// Runs three awaited steps in order:
/// 1. `record_into_history_with_policy` with `history_truncation_policy`,
/// 2. `persist_rollout_response_items`,
/// 3. `send_raw_response_items` (the only step that uses `turn_context`).
///
/// `history_truncation_policy` is passed only to step 1; the rollout and
/// raw-item steps receive `items` unchanged from this function.
pub(crate) async fn record_conversation_items_with_history_policy(
&self,
turn_context: &TurnContext,
items: &[ResponseItem],
history_truncation_policy: TruncationPolicy,
) {
// History is written first, with the explicit policy rather than the turn's.
self.record_into_history_with_policy(items, history_truncation_policy)
.await;
self.persist_rollout_response_items(items).await;
self.send_raw_response_items(turn_context, items).await;
}
Expand All @@ -2521,9 +2536,18 @@ impl Session {
&self,
items: &[ResponseItem],
turn_context: &TurnContext,
) {
self.record_into_history_with_policy(items, turn_context.truncation_policy)
.await;
}

pub(crate) async fn record_into_history_with_policy(
&self,
items: &[ResponseItem],
history_truncation_policy: TruncationPolicy,
) {
let mut state = self.state.lock().await;
state.record_items(items.iter(), turn_context.truncation_policy);
state.record_items(items.iter(), history_truncation_policy);
}

async fn maybe_warn_on_server_model_mismatch(
Expand Down
20 changes: 13 additions & 7 deletions codex-rs/core/src/session/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use crate::stream_events_utils::record_completed_response_item_with_finalized_fa
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::parallel::ToolCallRuntime;
use crate::tools::registry::RecordedToolResponse;
use crate::tools::registry::ToolArgumentDiffConsumer;
use crate::tools::router::ToolRouterParams;
use crate::tools::router::extension_tool_executors;
Expand Down Expand Up @@ -81,7 +82,6 @@ use codex_protocol::items::build_hook_prompt_message;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::ContentItem;
use codex_protocol::models::MessagePhase;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::AgentMessageContentDeltaEvent;
use codex_protocol::protocol::AgentReasoningSectionBreakEvent;
Expand Down Expand Up @@ -1676,16 +1676,22 @@ async fn handle_assistant_item_done_in_plan_mode(
}

async fn drain_in_flight(
in_flight: &mut FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>>,
in_flight: &mut FuturesOrdered<BoxFuture<'static, CodexResult<RecordedToolResponse>>>,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
) -> CodexResult<()> {
while let Some(res) = in_flight.next().await {
match res {
Ok(response_input) => {
let response_item = response_input.into();
sess.record_conversation_items(&turn_context, std::slice::from_ref(&response_item))
.await;
Ok(recorded_tool_response) => {
let response_item = recorded_tool_response.response_item.into();
sess.record_conversation_items_with_history_policy(
&turn_context,
std::slice::from_ref(&response_item),
recorded_tool_response
.history_truncation_policy
.unwrap_or(turn_context.truncation_policy),
)
.await;
mark_thread_memory_mode_polluted_if_external_context(
sess.as_ref(),
turn_context.as_ref(),
Expand Down Expand Up @@ -1747,7 +1753,7 @@ async fn try_run_sampling_request(
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)
.await??;
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>> =
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<RecordedToolResponse>>> =
FuturesOrdered::new();
let mut needs_follow_up = false;
let mut last_agent_message: Option<String> = None;
Expand Down
2 changes: 1 addition & 1 deletion codex-rs/core/src/stream_events_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ async fn record_stage1_output_usage_for_memory_citation(
/// queuing any tool execution futures. This records items immediately so
/// history and rollout stay in sync even if the turn is later cancelled.
pub(crate) type InFlightFuture<'f> =
Pin<Box<dyn Future<Output = Result<ResponseInputItem>> + Send + 'f>>;
Pin<Box<dyn Future<Output = Result<crate::tools::registry::RecordedToolResponse>> + Send + 'f>>;

#[derive(Default)]
pub(crate) struct OutputItemResult {
Expand Down
12 changes: 12 additions & 0 deletions codex-rs/core/src/tools/code_mode/execute_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use codex_tools::ToolSpec;

use super::ExecContext;
use super::PUBLIC_TOOL_NAME;
use super::code_mode_output_truncation_policy;
use super::handle_runtime_response;
use super::is_exec_tool_name;

Expand Down Expand Up @@ -127,4 +128,15 @@ impl CoreToolRuntime for CodeModeExecuteHandler {
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(payload, ToolPayload::Custom { .. })
}

/// Derives the history truncation policy for a code-mode `exec` invocation.
///
/// Returns `None` when the payload is not `ToolPayload::Custom` or when the
/// exec source cannot be parsed; otherwise builds the policy from the parsed
/// `max_output_tokens` via `code_mode_output_truncation_policy`.
fn history_truncation_policy(
    &self,
    invocation: &ToolInvocation,
) -> Option<codex_utils_output_truncation::TruncationPolicy> {
    match &invocation.payload {
        ToolPayload::Custom { input } => {
            // A parse failure yields no override rather than an error.
            let parsed = codex_code_mode::parse_exec_source(input).ok()?;
            Some(code_mode_output_truncation_policy(parsed.max_output_tokens))
        }
        _ => None,
    }
}
}
9 changes: 7 additions & 2 deletions codex-rs/core/src/tools/code_mode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ fn truncate_code_mode_result(
items: Vec<FunctionCallOutputContentItem>,
max_output_tokens: Option<usize>,
) -> Vec<FunctionCallOutputContentItem> {
let max_output_tokens = resolve_max_tokens(max_output_tokens);
let policy = TruncationPolicy::Tokens(max_output_tokens);
let policy = code_mode_output_truncation_policy(max_output_tokens);
if items
.iter()
.all(|item| matches!(item, FunctionCallOutputContentItem::InputText { .. }))
Expand All @@ -257,6 +256,12 @@ fn truncate_code_mode_result(
truncate_function_output_items_with_policy(&items, policy)
}

/// Builds the token-based truncation policy used for code-mode output.
///
/// The optional `max_output_tokens` is first passed through
/// `resolve_max_tokens`, and the resolved value is wrapped in
/// `TruncationPolicy::Tokens`.
pub(super) fn code_mode_output_truncation_policy(
    max_output_tokens: Option<usize>,
) -> TruncationPolicy {
    let token_budget = resolve_max_tokens(max_output_tokens);
    TruncationPolicy::Tokens(token_budget)
}

async fn call_nested_tool(
_exec: ExecContext,
tool_runtime: ToolCallRuntime,
Expand Down
14 changes: 13 additions & 1 deletion codex-rs/core/src/tools/code_mode/wait_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use codex_tools::ToolSpec;
use super::DEFAULT_WAIT_YIELD_TIME_MS;
use super::ExecContext;
use super::WAIT_TOOL_NAME;
use super::code_mode_output_truncation_policy;
use super::handle_runtime_response;
use super::wait_spec::create_wait_tool;

Expand Down Expand Up @@ -110,4 +111,15 @@ impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
}
}

impl CoreToolRuntime for CodeModeWaitHandler {}
impl CoreToolRuntime for CodeModeWaitHandler {
    /// Derives the history truncation policy for a code-mode `wait` invocation.
    ///
    /// Returns `None` when the payload is not `ToolPayload::Function` or when
    /// the arguments fail to parse as `ExecWaitArgs`; otherwise builds the
    /// policy from the parsed `max_tokens` via
    /// `code_mode_output_truncation_policy`.
    fn history_truncation_policy(
        &self,
        invocation: &ToolInvocation,
    ) -> Option<codex_utils_output_truncation::TruncationPolicy> {
        match &invocation.payload {
            ToolPayload::Function { arguments } => {
                // A parse failure yields no override rather than an error.
                let parsed: ExecWaitArgs = parse_arguments(arguments).ok()?;
                Some(code_mode_output_truncation_policy(parsed.max_tokens))
            }
            _ => None,
        }
    }
}
16 changes: 11 additions & 5 deletions codex-rs/core/src/tools/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::lifecycle::notify_tool_aborted;
use crate::tools::registry::AnyToolResult;
use crate::tools::registry::RecordedToolResponse;
use crate::tools::registry::ToolArgumentDiffConsumer;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolCallSource;
Expand Down Expand Up @@ -64,13 +65,13 @@ impl ToolCallRuntime {
self,
call: ToolCall,
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
) -> impl std::future::Future<Output = Result<RecordedToolResponse, CodexErr>> {
let error_call = call.clone();
let future =
self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token);
async move {
match future.await {
Ok(response) => Ok(response.into_response()),
Ok(response) => Ok(response.into_recorded_response()),
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
Err(other) => Ok(Self::failure_response(error_call, other)),
}
Expand Down Expand Up @@ -170,9 +171,9 @@ impl ToolCallRuntime {
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
}

fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
fn failure_response(call: ToolCall, err: FunctionCallError) -> RecordedToolResponse {
let message = err.to_string();
match call.payload {
let response_item = match call.payload {
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
call_id: call.call_id,
status: "completed".to_string(),
Expand All @@ -194,6 +195,10 @@ impl ToolCallRuntime {
success: Some(false),
},
},
};
RecordedToolResponse {
response_item,
history_truncation_policy: None,
}
}

Expand All @@ -205,6 +210,7 @@ impl ToolCallRuntime {
message: Self::abort_message(call, secs),
}),
post_tool_use_payload: None,
history_truncation_policy: None,
}
}

Expand Down Expand Up @@ -353,7 +359,7 @@ mod tests {
success: Some(true),
},
};
assert_eq!(expected_response, response);
assert_eq!(expected_response, response.response_item);

let actual = records
.lock()
Expand Down
32 changes: 32 additions & 0 deletions codex-rs/core/src/tools/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use codex_protocol::models::ResponseInputItem;
use codex_protocol::protocol::EventMsg;
use codex_tools::ToolName;
use codex_tools::ToolSpec;
use codex_utils_output_truncation::TruncationPolicy;
use futures::future::BoxFuture;
use serde_json::Value;
use tracing::warn;
Expand Down Expand Up @@ -70,6 +71,10 @@ pub(crate) trait CoreToolRuntime: ToolExecutor<ToolInvocation> {
None
}

/// Optional per-invocation truncation policy to apply when this tool's
/// response is recorded into history.
///
/// The default implementation returns `None`, i.e. the tool supplies no
/// override. Implementors may inspect the invocation to choose a policy.
fn history_truncation_policy(&self, _invocation: &ToolInvocation) -> Option<TruncationPolicy> {
None
}

fn pre_tool_use_payload(&self, _invocation: &ToolInvocation) -> Option<PreToolUsePayload> {
None
}
Expand Down Expand Up @@ -112,9 +117,16 @@ pub(crate) struct AnyToolResult {
pub(crate) payload: ToolPayload,
pub(crate) result: Box<dyn ToolOutput>,
pub(crate) post_tool_use_payload: Option<PostToolUsePayload>,
pub(crate) history_truncation_policy: Option<TruncationPolicy>,
}

/// A tool response paired with the truncation policy, if any, that should be
/// used when the response is recorded into history.
pub(crate) struct RecordedToolResponse {
/// The response item produced for the tool call.
pub(crate) response_item: ResponseInputItem,
/// Truncation policy override for history recording; `None` means no
/// override was supplied for this response.
pub(crate) history_truncation_policy: Option<TruncationPolicy>,
}

impl AnyToolResult {
#[cfg(test)]
pub(crate) fn into_response(self) -> ResponseInputItem {
let Self {
call_id,
Expand All @@ -125,6 +137,20 @@ impl AnyToolResult {
result.to_response_item(&call_id, &payload)
}

pub(crate) fn into_recorded_response(self) -> RecordedToolResponse {
let Self {
call_id,
payload,
result,
history_truncation_policy,
..
} = self;
RecordedToolResponse {
response_item: result.to_response_item(&call_id, &payload),
history_truncation_policy,
}
}

pub(crate) fn code_mode_result(self) -> serde_json::Value {
let Self {
payload, result, ..
Expand Down Expand Up @@ -225,6 +251,10 @@ impl CoreToolRuntime for ExposureOverride {
self.handler.post_tool_use_payload(invocation, result)
}

/// Delegates to the wrapped handler so the exposure override does not change
/// which history truncation policy the underlying tool reports.
fn history_truncation_policy(&self, invocation: &ToolInvocation) -> Option<TruncationPolicy> {
self.handler.history_truncation_policy(invocation)
}

fn with_updated_hook_input(
&self,
invocation: ToolInvocation,
Expand Down Expand Up @@ -610,11 +640,13 @@ async fn handle_any_tool(
let output = tool.handle(invocation.clone()).await?;
let post_tool_use_payload =
CoreToolRuntime::post_tool_use_payload(tool, &invocation, output.as_ref());
let history_truncation_policy = CoreToolRuntime::history_truncation_policy(tool, &invocation);
Ok(AnyToolResult {
call_id,
payload,
result: output,
post_tool_use_payload,
history_truncation_policy,
})
}

Expand Down
Loading