From f0a216c39bd7574f6e17a112eaba8cb671ca18fe Mon Sep 17 00:00:00 2001
From: Albin Cassirer
Date: Mon, 27 Apr 2026 10:16:26 -0700
Subject: [PATCH 1/4] Trace cancelled inference streams

---
 codex-rs/core/src/client.rs | 36 +++++-
 codex-rs/core/src/client_common.rs | 10 ++
 codex-rs/core/src/client_tests.rs | 103 +++++++++++++++++
 codex-rs/rollout-trace/src/inference.rs | 109 +++++++++++++-----
 codex-rs/rollout-trace/src/raw_event.rs | 11 ++
 .../rollout-trace/src/reducer/inference.rs | 33 ++++++
 .../src/reducer/inference_tests.rs | 92 +++++++++++++++
 codex-rs/rollout-trace/src/reducer/mod.rs | 14 +++
 codex-rs/rollout-trace/src/reducer/thread.rs | 6 +
 codex-rs/rollout-trace/src/thread_tests.rs | 2 +-
 10 files changed, 383 insertions(+), 33 deletions(-)
 create mode 100644 codex-rs/rollout-trace/src/reducer/inference_tests.rs

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index c49e28f20a81..b57d9e074877 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -98,6 +98,7 @@ use tokio::sync::oneshot;
 use tokio::sync::oneshot::error::TryRecvError;
 use tokio_tungstenite::tungstenite::Error;
 use tokio_tungstenite::tungstenite::Message;
+use tokio_util::sync::CancellationToken;
 use tracing::instrument;
 use tracing::trace;
 use tracing::warn;
@@ -1232,7 +1233,7 @@ impl ModelClientSession {
             Err(ApiError::Transport(
                 unauthorized_transport @ TransportError::Http { status, .. },
             )) if status == StatusCode::UNAUTHORIZED => {
-                inference_trace_attempt.record_failed(&unauthorized_transport);
+                inference_trace_attempt.record_failed(&unauthorized_transport, &[]);
                 pending_retry = PendingUnauthorizedRetry::from_recovery(
                     handle_unauthorized(
                         unauthorized_transport,
@@ -1245,7 +1246,7 @@
             }
             Err(err) => {
                 let err = map_api_error(err);
-                inference_trace_attempt.record_failed(&err);
+                inference_trace_attempt.record_failed(&err, &[]);
                 return Err(err);
             }
         }
@@ -1372,7 +1373,7 @@
             .await
             .map_err(|err| {
                 let err = map_api_error(err);
-                inference_trace_attempt.record_failed(&err);
+                inference_trace_attempt.record_failed(&err, &[]);
                 err
             })?;
         let (stream, last_request_rx) = map_response_stream(
@@ -1634,13 +1635,28 @@ where
         + Send
         + 'static,
 {
     let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
     let (tx_last_response, rx_last_response) = oneshot::channel();
+    let consumer_dropped = CancellationToken::new();
+    let consumer_dropped_for_stream = consumer_dropped.clone();

     tokio::spawn(async move {
         let mut logged_error = false;
         let mut tx_last_response = Some(tx_last_response);
         let mut items_added: Vec<ResponseItem> = Vec::new();
         let mut api_stream = api_stream;

-        while let Some(event) = api_stream.next().await {
+        loop {
+            let event = tokio::select! {
+                _ = consumer_dropped.cancelled() => {
+                    inference_trace_attempt.record_cancelled(
+                        "response stream dropped before provider terminal event",
+                        &items_added,
+                    );
+                    return;
+                }
+                event = api_stream.next() => event,
+            };
+            let Some(event) = event else {
+                break;
+            };
             match event {
                 Ok(ResponseEvent::OutputItemDone(item)) => {
                     items_added.push(item.clone());
@@ -1696,7 +1712,7 @@
                 }
                 Err(err) => {
                     let mapped = map_api_error(err);
-                    inference_trace_attempt.record_failed(&mapped);
+                    inference_trace_attempt.record_failed(&mapped, &items_added);
                     if !logged_error {
                         session_telemetry.see_event_completed_failed(&mapped);
                         logged_error = true;
@@ -1707,9 +1723,17 @@
                 }
             }
         }
+        inference_trace_attempt
+            .record_failed("stream closed before response.completed", &items_added);
     });

-    (ResponseStream { rx_event }, rx_last_response)
+    (
+        ResponseStream {
+            rx_event,
+            consumer_dropped: consumer_dropped_for_stream,
+        },
+        rx_last_response,
+    )
 }

 /// Handles a 401 response by optionally refreshing ChatGPT tokens once.
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index e8e37540033f..efe2670652b1 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -13,6 +13,7 @@ use std::pin::Pin;
 use std::task::Context;
 use std::task::Poll;
 use tokio::sync::mpsc;
+use tokio_util::sync::CancellationToken;

 /// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
 pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
@@ -175,6 +176,9 @@ fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {

 pub struct ResponseStream {
     pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
+    /// Signals the mapper task that the consumer stopped polling before the
+    /// provider stream reached its own terminal event.
+    pub(crate) consumer_dropped: CancellationToken,
 }

 impl Stream for ResponseStream {
@@ -185,6 +189,12 @@
     }
 }

+impl Drop for ResponseStream {
+    fn drop(&mut self) {
+        self.consumer_dropped.cancel();
+    }
+}
+
 #[cfg(test)]
 #[path = "client_common_tests.rs"]
 mod tests;
diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs
index f4575b26a0b2..4e56d3d90739 100644
--- a/codex-rs/core/src/client_tests.rs
+++ b/codex-rs/core/src/client_tests.rs
@@ -7,17 +7,29 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER;
 use super::X_CODEX_TURN_METADATA_HEADER;
 use super::X_CODEX_WINDOW_ID_HEADER;
 use super::X_OPENAI_SUBAGENT_HEADER;
+use codex_api::ResponseEvent;
 use codex_app_server_protocol::AuthMode;
 use codex_model_provider::BearerAuthProvider;
 use codex_model_provider_info::WireApi;
 use codex_model_provider_info::create_oss_provider_with_base_url;
 use codex_otel::SessionTelemetry;
 use codex_protocol::ThreadId;
+use codex_protocol::models::ContentItem;
+use codex_protocol::models::ResponseItem;
 use codex_protocol::openai_models::ModelInfo;
 use codex_protocol::protocol::SessionSource;
 use codex_protocol::protocol::SubAgentSource;
+use codex_rollout_trace::ExecutionStatus;
+use codex_rollout_trace::InferenceTraceContext;
+use codex_rollout_trace::RawTraceEventPayload;
+use codex_rollout_trace::TraceWriter;
+use codex_rollout_trace::replay_bundle;
+use futures::StreamExt;
 use pretty_assertions::assert_eq;
 use serde_json::json;
+use std::sync::Arc;
+use std::time::Duration;
+use tempfile::TempDir;

 fn test_model_client(session_source: SessionSource) -> ModelClient {
     let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
@@ -151,6 +163,97 @@ async fn summarize_memories_returns_empty_for_empty_input() {
     assert_eq!(output.len(), 0);
 }

+#[tokio::test]
+async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> {
+    let temp = TempDir::new()?;
+    let writer = Arc::new(TraceWriter::create(
+        temp.path(),
+        "trace-1".to_string(),
+        "rollout-1".to_string(),
+        "thread-root".to_string(),
+    )?);
+    writer.append(RawTraceEventPayload::ThreadStarted {
+        thread_id: "thread-root".to_string(),
+        agent_path: "/root".to_string(),
+        metadata_payload: None,
+    })?;
+    writer.append(RawTraceEventPayload::CodexTurnStarted {
+        codex_turn_id: "turn-1".to_string(),
+        thread_id: "thread-root".to_string(),
+    })?;
+
+    let inference_trace = InferenceTraceContext::enabled(
+        writer,
+        "thread-root".to_string(),
+        "turn-1".to_string(),
+        "gpt-test".to_string(),
+        "test-provider".to_string(),
+    );
+    let attempt = inference_trace.start_attempt();
+    attempt.record_started(&json!({
+        "model": "gpt-test",
+        "input": [{
+            "type": "message",
+            "role": "user",
+            "content": [{"type": "input_text", "text": "hello"}]
+        }],
+    }));
+
+    // The provider has produced one complete output item, but no terminal
+    // response.completed event. The harness has enough information to keep this
+    // item in history, so the trace should preserve it when the stream is
+    // abandoned.
+    let item = ResponseItem::Message {
+        id: Some("msg-1".to_string()),
+        role: "assistant".to_string(),
+        content: vec![ContentItem::OutputText {
+            text: "partial answer".to_string(),
+        }],
+        phase: None,
+    };
+    let api_stream = futures::stream::iter([Ok(ResponseEvent::OutputItemDone(item))])
+        .chain(futures::stream::pending());
+    let (mut stream, _) = super::map_response_stream(api_stream, test_session_telemetry(), attempt);
+
+    let observed = stream
+        .next()
+        .await
+        .expect("mapped stream should yield output item")?;
+    assert!(matches!(observed, ResponseEvent::OutputItemDone(_)));
+
+    // Dropping the consumer is how turn interruption/preemption stops polling
+    // the provider stream. The mapper task observes that drop asynchronously
+    // and records cancellation using the output items it has already seen.
+    drop(stream);
+
+    // Cancellation is recorded by the mapper task after Drop wakes it, so the
+    // replay may need a short wait before the terminal event appears on disk.
+    let mut rollout = replay_bundle(temp.path())?;
+    for _ in 0..50 {
+        let inference = rollout
+            .inference_calls
+            .values()
+            .next()
+            .expect("inference should be reduced");
+        if inference.execution.status == ExecutionStatus::Cancelled {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(10)).await;
+        rollout = replay_bundle(temp.path())?;
+    }
+    let inference = rollout
+        .inference_calls
+        .values()
+        .next()
+        .expect("inference should be reduced");
+
+    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
+    assert_eq!(inference.response_item_ids.len(), 1);
+    assert_eq!(rollout.raw_payloads.len(), 2);
+
+    Ok(())
+}
+
 #[test]
 fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
     let auth_context = AuthRequestTelemetryContext::new(
diff --git a/codex-rs/rollout-trace/src/inference.rs b/codex-rs/rollout-trace/src/inference.rs
index 935a2af6c920..3f76cbdf6244 100644
--- a/codex-rs/rollout-trace/src/inference.rs
+++ b/codex-rs/rollout-trace/src/inference.rs
@@ -6,6 +6,7 @@
 use std::fmt::Display;
 use std::sync::Arc;
+use std::sync::atomic::AtomicBool;
 use std::sync::atomic::AtomicU64;
 use std::sync::atomic::Ordering;
@@ -54,35 +55,36 @@ struct EnabledInferenceTraceContext {
 ///
 /// A Codex turn can create multiple attempts when auth recovery retries the
 /// HTTP request or WebSocket setup falls back to HTTP. Completion is often
-/// observed after the client returns the response stream, so attempts are
-/// cloneable and self-contained.
-#[derive(Clone, Debug)]
+/// observed after the client returns the response stream, so the attempt owns
+/// the terminal guard that prevents duplicate lifecycle events.
+#[derive(Debug)]
 pub struct InferenceTraceAttempt {
     state: InferenceTraceAttemptState,
 }

-#[derive(Clone, Debug)]
+#[derive(Debug)]
 enum InferenceTraceAttemptState {
     Disabled,
     Enabled(EnabledInferenceTraceAttempt),
 }

-#[derive(Clone, Debug)]
+#[derive(Debug)]
 struct EnabledInferenceTraceAttempt {
     context: EnabledInferenceTraceContext,
     inference_call_id: InferenceCallId,
+    terminal_recorded: AtomicBool,
 }

-/// Non-delta response payload saved when a traced inference stream completes.
+/// Non-delta response payload saved for completed or interrupted inference streams.
 ///
 /// We intentionally record completed output items instead of every stream delta
 /// here. The raw stream can be added later as a separate payload class; this
-/// response summary gives the reducer stable response identity, usage, and
-/// model-visible output without duplicating high-volume text deltas.
+/// response summary gives the reducer stable response identity when available
+/// plus model-visible output without duplicating high-volume text deltas.
 #[derive(Serialize)]
-struct TracedResponseStreamCompleted<'a> {
-    response_id: &'a str,
-    token_usage: &'a Option<TokenUsage>,
+struct TracedResponseStreamOutput<'a> {
+    response_id: Option<&'a str>,
+    token_usage: Option<&'a TokenUsage>,
     output_items: Vec<serde_json::Value>,
 }
@@ -123,6 +125,7 @@
             state: InferenceTraceAttemptState::Enabled(EnabledInferenceTraceAttempt {
                 context: context.clone(),
                 inference_call_id: next_inference_call_id(),
+                terminal_recorded: AtomicBool::new(false),
             }),
         }
     }
@@ -162,9 +165,9 @@
         );
     }

-    /// Records a bounded, non-streaming summary of the completed response stream.
+    /// Records successful provider completion and serializes the observed output items.
     ///
-    /// The caller passes protocol-native response items so this crate owns the
+    /// Callers pass protocol-native response items so this crate owns the
     /// trace-specific serialization rules. That keeps codex-core focused on
     /// transport behavior while preserving trace evidence that normal request
     /// serialization intentionally omits.
@@ -174,18 +177,14 @@
         token_usage: &Option<TokenUsage>,
         output_items: &[ResponseItem],
     ) {
-        let InferenceTraceAttemptState::Enabled(attempt) = &self.state else {
+        let Some(attempt) = self.take_terminal_attempt() else {
             return;
         };
-        let response_payload = TracedResponseStreamCompleted {
-            response_id,
-            token_usage,
-            output_items: output_items.iter().map(trace_response_item_json).collect(),
-        };
-        let Some(response_payload) = write_json_payload_best_effort(
-            &attempt.context.writer,
-            RawPayloadKind::InferenceResponse,
-            &response_payload,
+        let Some(response_payload) = write_response_payload_best_effort(
+            attempt,
+            Some(response_id),
+            token_usage.as_ref(),
+            output_items,
         ) else {
             return;
         };
@@ -201,19 +200,59 @@
     }

     /// Records pre-response and mid-stream failures.
-    pub fn record_failed(&self, error: impl Display) {
-        let InferenceTraceAttemptState::Enabled(attempt) = &self.state else {
+    pub fn record_failed(&self, error: impl Display, output_items: &[ResponseItem]) {
+        let Some(attempt) = self.take_terminal_attempt() else {
             return;
         };
+        let partial_response_payload = if output_items.is_empty() {
+            None
+        } else {
+            write_response_payload_best_effort(attempt, None, None, output_items)
+        };
         append_with_context_best_effort(
             &attempt.context,
             RawTraceEventPayload::InferenceFailed {
                 inference_call_id: attempt.inference_call_id.clone(),
                 error: error.to_string(),
-                partial_response_payload: None,
+                partial_response_payload,
             },
         );
     }
+
+    /// Records a provider stream that Codex intentionally stopped consuming.
+    ///
+    /// This happens when the turn is interrupted or when mailbox delivery
+    /// preempts the current sampling request. Complete output items observed
+    /// before that point are retained as partial response evidence.
+    pub fn record_cancelled(&self, reason: impl Display, output_items: &[ResponseItem]) {
+        let Some(attempt) = self.take_terminal_attempt() else {
+            return;
+        };
+        let partial_response_payload = if output_items.is_empty() {
+            None
+        } else {
+            write_response_payload_best_effort(attempt, None, None, output_items)
+        };
+        append_with_context_best_effort(
+            &attempt.context,
+            RawTraceEventPayload::InferenceCancelled {
+                inference_call_id: attempt.inference_call_id.clone(),
+                reason: reason.to_string(),
+                partial_response_payload,
+            },
+        );
+    }
+
+    fn take_terminal_attempt(&self) -> Option<&EnabledInferenceTraceAttempt> {
+        let attempt = match &self.state {
+            InferenceTraceAttemptState::Disabled => return None,
+            InferenceTraceAttemptState::Enabled(attempt) => attempt,
+        };
+        if attempt.terminal_recorded.swap(true, Ordering::AcqRel) {
+            return None;
+        }
+        Some(attempt)
+    }
 }

 /// Serializes a response item for trace evidence rather than future request construction.
@@ -260,6 +299,24 @@ fn write_json_payload_best_effort(
     writer.write_json_payload(kind, payload).ok()
 }

+fn write_response_payload_best_effort(
+    attempt: &EnabledInferenceTraceAttempt,
+    response_id: Option<&str>,
+    token_usage: Option<&TokenUsage>,
+    output_items: &[ResponseItem],
+) -> Option<RawPayloadRef> {
+    let response_payload = TracedResponseStreamOutput {
+        response_id,
+        token_usage,
+        output_items: output_items.iter().map(trace_response_item_json).collect(),
+    };
+    write_json_payload_best_effort(
+        &attempt.context.writer,
+        RawPayloadKind::InferenceResponse,
+        &response_payload,
+    )
+}
+
 fn append_with_context_best_effort(
     context: &EnabledInferenceTraceContext,
     payload: RawTraceEventPayload,
diff --git a/codex-rs/rollout-trace/src/raw_event.rs b/codex-rs/rollout-trace/src/raw_event.rs
index b364601408e4..ad8f436da7dc 100644
--- a/codex-rs/rollout-trace/src/raw_event.rs
+++ b/codex-rs/rollout-trace/src/raw_event.rs
@@ -110,6 +110,13 @@ pub enum RawTraceEventPayload {
         /// Partial response payload, when stream events arrived before failure.
         partial_response_payload: Option<RawPayloadRef>,
     },
+    InferenceCancelled {
+        inference_call_id: InferenceCallId,
+        /// Why Codex stopped consuming the provider stream before a terminal response event.
+        reason: String,
+        /// Completed output items observed before cancellation, if any.
+        partial_response_payload: Option<RawPayloadRef>,
+    },
     ToolCallStarted {
         tool_call_id: ToolCallId,
         /// Protocol/model call ID when this runtime call came from model output.
@@ -253,6 +260,10 @@ impl RawTraceEventPayload {
                 partial_response_payload,
                 ..
             }
+            | RawTraceEventPayload::InferenceCancelled {
+                partial_response_payload,
+                ..
+            }
             | RawTraceEventPayload::ToolCallStarted {
                 invocation_payload: partial_response_payload,
                 ..
diff --git a/codex-rs/rollout-trace/src/reducer/inference.rs b/codex-rs/rollout-trace/src/reducer/inference.rs
index ddd08142ff7c..622fd911025b 100644
--- a/codex-rs/rollout-trace/src/reducer/inference.rs
+++ b/codex-rs/rollout-trace/src/reducer/inference.rs
@@ -103,6 +103,35 @@ impl TraceReducer {
         Ok(())
     }

+    /// Closes any inference streams that are still live when the owning turn ends.
+    ///
+    /// Normal completion events close the active inference before the turn ends.
+    /// If a call is still `Running`, Codex stopped observing that provider stream
+    /// earlier and the reduced graph should not present it as live.
+    pub(super) fn close_running_inference_calls_for_turn_end(
+        &mut self,
+        seq: RawEventSeq,
+        wall_time_unix_ms: i64,
+        codex_turn_id: &str,
+        turn_status: &ExecutionStatus,
+    ) {
+        let inference_status = match turn_status {
+            ExecutionStatus::Running => return,
+            ExecutionStatus::Completed | ExecutionStatus::Cancelled => ExecutionStatus::Cancelled,
+            ExecutionStatus::Failed => ExecutionStatus::Failed,
+            ExecutionStatus::Aborted => ExecutionStatus::Aborted,
+        };
+        for inference in self.rollout.inference_calls.values_mut() {
+            if inference.codex_turn_id == codex_turn_id
+                && inference.execution.status == ExecutionStatus::Running
+            {
+                inference.execution.ended_at_unix_ms = Some(wall_time_unix_ms);
+                inference.execution.ended_seq = Some(seq);
+                inference.execution.status = inference_status.clone();
+            }
+        }
+    }
+
     /// Completes an inference call and, when present, reduces response output items.
     pub(super) fn complete_inference_call(
         &mut self,
@@ -141,3 +170,7 @@
         Ok(())
     }
 }
+
+#[cfg(test)]
+#[path = "inference_tests.rs"]
+mod tests;
diff --git a/codex-rs/rollout-trace/src/reducer/inference_tests.rs b/codex-rs/rollout-trace/src/reducer/inference_tests.rs
new file mode 100644
index 000000000000..ff5cda08e461
--- /dev/null
+++ b/codex-rs/rollout-trace/src/reducer/inference_tests.rs
@@ -0,0 +1,92 @@
+use pretty_assertions::assert_eq;
+use serde_json::json;
+use tempfile::TempDir;
+
+use crate::model::ConversationItemKind;
+use crate::model::ExecutionStatus;
+use crate::payload::RawPayloadKind;
+use crate::raw_event::RawTraceEventPayload;
+use crate::reducer::test_support::append_inference_start;
+use crate::reducer::test_support::create_started_writer;
+use crate::reducer::test_support::message;
+use crate::reducer::test_support::start_turn;
+use crate::replay_bundle;
+
+#[test]
+fn cancelled_inference_reduces_partial_response_items() -> anyhow::Result<()> {
+    let temp = TempDir::new()?;
+    let writer = create_started_writer(&temp)?;
+    start_turn(&writer, "turn-1")?;
+
+    let request = writer.write_json_payload(
+        RawPayloadKind::InferenceRequest,
+        &json!({
+            "input": [message("user", "draft")]
+        }),
+    )?;
+    append_inference_start(&writer, "inference-1", "turn-1", request)?;
+
+    let partial_response = writer.write_json_payload(
+        RawPayloadKind::InferenceResponse,
+        &json!({
+            "response_id": null,
+            "token_usage": null,
+            "output_items": [{
+                "type": "message",
+                "role": "assistant",
+                "content": [{"type": "output_text", "text": "partial"}]
+            }]
+        }),
+    )?;
+    writer.append(RawTraceEventPayload::InferenceCancelled {
+        inference_call_id: "inference-1".to_string(),
+        reason: "test interruption".to_string(),
+        partial_response_payload: Some(partial_response),
+    })?;
+
+    let rollout = replay_bundle(temp.path())?;
+    let inference = &rollout.inference_calls["inference-1"];
+    let response_item_id = &inference.response_item_ids[0];
+
+    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
+    assert_eq!(inference.response_item_ids.len(), 1);
+    assert_eq!(
+        rollout.conversation_items[response_item_id].kind,
+        ConversationItemKind::Message,
+    );
+    assert_eq!(
+        rollout.conversation_items[response_item_id].produced_by,
+        vec![crate::model::ProducerRef::Inference {
+            inference_call_id: "inference-1".to_string(),
+        }],
+    );
+
+    Ok(())
+}
+
+#[test]
+fn cancelled_turn_closes_running_inference_call() -> anyhow::Result<()> {
+    let temp = TempDir::new()?;
+    let writer = create_started_writer(&temp)?;
+    start_turn(&writer, "turn-1")?;
+
+    let request = writer.write_json_payload(
+        RawPayloadKind::InferenceRequest,
+        &json!({
+            "input": [message("user", "wait")]
+        }),
+    )?;
+    append_inference_start(&writer, "inference-1", "turn-1", request)?;
+    let turn_end = writer.append(RawTraceEventPayload::CodexTurnEnded {
+        codex_turn_id: "turn-1".to_string(),
+        status: ExecutionStatus::Cancelled,
+    })?;
+
+    let rollout = replay_bundle(temp.path())?;
+    let inference = &rollout.inference_calls["inference-1"];
+
+    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
+    assert_eq!(inference.execution.ended_seq, Some(turn_end.seq));
+
+    Ok(())
+}
diff --git a/codex-rs/rollout-trace/src/reducer/mod.rs b/codex-rs/rollout-trace/src/reducer/mod.rs
index e4f6b837f455..34a4e5dcfd25 100644
--- a/codex-rs/rollout-trace/src/reducer/mod.rs
+++ b/codex-rs/rollout-trace/src/reducer/mod.rs
@@ -254,6 +254,20 @@ impl TraceReducer {
                     partial_response_payload,
                 )?;
             }
+            RawTraceEventPayload::InferenceCancelled {
+                inference_call_id,
+                partial_response_payload,
+                ..
+            } => {
+                self.complete_inference_call(
+                    event.seq,
+                    event.wall_time_unix_ms,
+                    inference_call_id,
+                    ExecutionStatus::Cancelled,
+                    /*response_id*/ None,
+                    partial_response_payload,
+                )?;
+            }
             RawTraceEventPayload::ProtocolEventObserved { .. } => {
                 // Protocol wrappers are raw debug breadcrumbs. Typed hooks own
                 // the reduced graph, so these payload refs are retained without
diff --git a/codex-rs/rollout-trace/src/reducer/thread.rs b/codex-rs/rollout-trace/src/reducer/thread.rs
index 6f24f694701d..4ef39ee56309 100644
--- a/codex-rs/rollout-trace/src/reducer/thread.rs
+++ b/codex-rs/rollout-trace/src/reducer/thread.rs
@@ -187,6 +187,12 @@ impl TraceReducer {
             &codex_turn_id,
             &status,
         )?;
+        self.close_running_inference_calls_for_turn_end(
+            seq,
+            wall_time_unix_ms,
+            &codex_turn_id,
+            &status,
+        );
         Ok(())
     }
diff --git a/codex-rs/rollout-trace/src/thread_tests.rs b/codex-rs/rollout-trace/src/thread_tests.rs
index 4d582bbe0f36..a27026f28d8f 100644
--- a/codex-rs/rollout-trace/src/thread_tests.rs
+++ b/codex-rs/rollout-trace/src/thread_tests.rs
@@ -134,7 +134,7 @@ fn disabled_thread_context_accepts_trace_calls_without_writing() -> anyhow::Resu
     inference_attempt.record_started(&serde_json::json!({ "kind": "inference" }));
     let token_usage: Option<TokenUsage> = None;
     inference_attempt.record_completed("response-1", &token_usage, &[]);
-    inference_attempt.record_failed("inference failed");
+    inference_attempt.record_failed("inference failed", &[]);

     let compaction_trace = thread_trace.compaction_trace_context(
         "turn-1",

From c5166eeb7d4ce48039fadea687eae286fbb44919 Mon Sep 17 00:00:00 2001
From: Albin Cassirer
Date: Mon, 27 Apr 2026 11:10:05 -0700
Subject: [PATCH 2/4] Address argument-comment lint

---
 codex-rs/core/src/client.rs             |  7 ++++---
 codex-rs/rollout-trace/src/inference.rs | 14 ++++++++++++--
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index b57d9e074877..a6b0260330e2 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -1233,7 +1233,8 @@ impl ModelClientSession {
             Err(ApiError::Transport(
                 unauthorized_transport @ TransportError::Http { status, .. },
             )) if status == StatusCode::UNAUTHORIZED => {
-                inference_trace_attempt.record_failed(&unauthorized_transport, &[]);
+                inference_trace_attempt
+                    .record_failed(&unauthorized_transport, /*output_items*/ &[]);
                 pending_retry = PendingUnauthorizedRetry::from_recovery(
                     handle_unauthorized(
                         unauthorized_transport,
@@ -1246,7 +1247,7 @@
             }
             Err(err) => {
                 let err = map_api_error(err);
-                inference_trace_attempt.record_failed(&err, &[]);
+                inference_trace_attempt.record_failed(&err, /*output_items*/ &[]);
                 return Err(err);
             }
         }
@@ -1373,7 +1374,7 @@
             .await
             .map_err(|err| {
                 let err = map_api_error(err);
-                inference_trace_attempt.record_failed(&err, &[]);
+                inference_trace_attempt.record_failed(&err, /*output_items*/ &[]);
                 err
             })?;
         let (stream, last_request_rx) = map_response_stream(
diff --git a/codex-rs/rollout-trace/src/inference.rs b/codex-rs/rollout-trace/src/inference.rs
index 3f76cbdf6244..3162308fb37e 100644
--- a/codex-rs/rollout-trace/src/inference.rs
+++ b/codex-rs/rollout-trace/src/inference.rs
@@ -207,7 +207,12 @@ impl InferenceTraceAttempt {
         let partial_response_payload = if output_items.is_empty() {
             None
         } else {
-            write_response_payload_best_effort(attempt, None, None, output_items)
+            write_response_payload_best_effort(
+                attempt,
+                /*response_id*/ None,
+                /*token_usage*/ None,
+                output_items,
+            )
         };
         append_with_context_best_effort(
             &attempt.context,
@@ -231,7 +236,12 @@
         let partial_response_payload = if output_items.is_empty() {
             None
         } else {
-            write_response_payload_best_effort(attempt, None, None, output_items)
+            write_response_payload_best_effort(
+                attempt,
+                /*response_id*/ None,
+                /*token_usage*/ None,
+                output_items,
+            )
         };
         append_with_context_best_effort(
             &attempt.context,

From d7b428d308da77a4cc9e4a12966d1ae1c95a5282 Mon Sep 17 00:00:00 2001
From: Albin Cassirer
Date: Mon, 27 Apr 2026 12:54:58 -0700
Subject: [PATCH 3/4] Trace response stream send cancellations

---
 codex-rs/core/src/client.rs       |  15 ++-
 codex-rs/core/src/client_tests.rs | 192 ++++++++++++++++++++++--------
 2 files changed, 150 insertions(+), 57 deletions(-)

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index a6b0260330e2..3997018eaa5f 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -1623,6 +1623,9 @@ fn parent_thread_id_header_value(session_source: &SessionSource) -> Option<String> {
+const STREAM_DROPPED_REASON: &str = "response stream dropped before provider terminal event";
+const RESPONSE_STREAM_CHANNEL_CAPACITY: usize = 1600;
+
 fn map_response_stream<S>(
     api_stream: S,
     session_telemetry: SessionTelemetry,
@@ -1634,7 +1637,8 @@ where
         + Send
         + 'static,
 {
-    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
+    let (tx_event, rx_event) =
+        mpsc::channel::<Result<ResponseEvent>>(RESPONSE_STREAM_CHANNEL_CAPACITY);
     let (tx_last_response, rx_last_response) = oneshot::channel();
     let consumer_dropped = CancellationToken::new();
     let consumer_dropped_for_stream = consumer_dropped.clone();
@@ -1647,10 +1651,7 @@
         loop {
             let event = tokio::select! {
                 _ = consumer_dropped.cancelled() => {
-                    inference_trace_attempt.record_cancelled(
-                        "response stream dropped before provider terminal event",
-                        &items_added,
-                    );
+                    inference_trace_attempt.record_cancelled(STREAM_DROPPED_REASON, &items_added);
                     return;
                 }
                 event = api_stream.next() => event,
@@ -1666,6 +1667,8 @@
                         .await
                         .is_err()
                     {
+                        inference_trace_attempt
+                            .record_cancelled(STREAM_DROPPED_REASON, &items_added);
                         return;
                     }
                 }
@@ -1708,6 +1711,8 @@
                 }
                 Ok(event) => {
                     if tx_event.send(Ok(event)).await.is_err() {
+                        inference_trace_attempt
+                            .record_cancelled(STREAM_DROPPED_REASON, &items_added);
                         return;
                     }
                 }
diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs
index 4e56d3d90739..e2934dc1ccb2 100644
--- a/codex-rs/core/src/client_tests.rs
+++ b/codex-rs/core/src/client_tests.rs
@@ -7,6 +7,7 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER;
 use super::X_CODEX_TURN_METADATA_HEADER;
 use super::X_CODEX_WINDOW_ID_HEADER;
 use super::X_OPENAI_SUBAGENT_HEADER;
+use codex_api::ApiError;
 use codex_api::ResponseEvent;
 use codex_app_server_protocol::AuthMode;
 use codex_model_provider::BearerAuthProvider;
@@ -20,16 +21,23 @@
 use codex_protocol::openai_models::ModelInfo;
 use codex_protocol::protocol::SessionSource;
 use codex_protocol::protocol::SubAgentSource;
 use codex_rollout_trace::ExecutionStatus;
+use codex_rollout_trace::InferenceTraceAttempt;
 use codex_rollout_trace::InferenceTraceContext;
 use codex_rollout_trace::RawTraceEventPayload;
+use codex_rollout_trace::RolloutTrace;
 use codex_rollout_trace::TraceWriter;
 use codex_rollout_trace::replay_bundle;
 use futures::StreamExt;
 use pretty_assertions::assert_eq;
 use serde_json::json;
+use std::collections::VecDeque;
+use std::pin::Pin;
 use std::sync::Arc;
+use std::task::Context;
+use std::task::Poll;
 use std::time::Duration;
 use tempfile::TempDir;
+use tokio::sync::Notify;

 fn test_model_client(session_source: SessionSource) -> ModelClient {
     let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
@@ -91,6 +99,92 @@ fn test_session_telemetry() -> SessionTelemetry {
     )
 }

+fn started_inference_attempt(temp: &TempDir) -> anyhow::Result<InferenceTraceAttempt> {
+    let writer = Arc::new(TraceWriter::create(
+        temp.path(),
+        "trace-1".to_string(),
+        "rollout-1".to_string(),
+        "thread-root".to_string(),
+    )?);
+    writer.append(RawTraceEventPayload::ThreadStarted {
+        thread_id: "thread-root".to_string(),
+        agent_path: "/root".to_string(),
+        metadata_payload: None,
+    })?;
+    writer.append(RawTraceEventPayload::CodexTurnStarted {
+        codex_turn_id: "turn-1".to_string(),
+        thread_id: "thread-root".to_string(),
+    })?;
+
+    let inference_trace = InferenceTraceContext::enabled(
+        writer,
+        "thread-root".to_string(),
+        "turn-1".to_string(),
+        "gpt-test".to_string(),
+        "test-provider".to_string(),
+    );
+    let attempt = inference_trace.start_attempt();
+    attempt.record_started(&json!({
+        "model": "gpt-test",
+        "input": [{
+            "type": "message",
+            "role": "user",
+            "content": [{"type": "input_text", "text": "hello"}]
+        }],
+    }));
+    Ok(attempt)
+}
+
+fn output_message(id: &str, text: &str) -> ResponseItem {
+    ResponseItem::Message {
+        id: Some(id.to_string()),
+        role: "assistant".to_string(),
+        content: vec![ContentItem::OutputText {
+            text: text.to_string(),
+        }],
+        phase: None,
+    }
+}
+
+async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result<RolloutTrace> {
+    let mut rollout = replay_bundle(temp.path())?;
+    for _ in 0..50 {
+        let inference = rollout
+            .inference_calls
+            .values()
+            .next()
+            .expect("inference should be reduced");
+        if inference.execution.status == ExecutionStatus::Cancelled {
+            return Ok(rollout);
+        }
+        tokio::time::sleep(Duration::from_millis(10)).await;
+        rollout = replay_bundle(temp.path())?;
+    }
+    Ok(rollout)
+}
+
+struct NotifyAfterEventStream {
+    events: VecDeque<ResponseEvent>,
+    yielded: usize,
+    notify_after: usize,
+    notify: Arc<Notify>,
+}
+
+impl futures::Stream for NotifyAfterEventStream {
+    type Item = std::result::Result<ResponseEvent, ApiError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let Some(event) = self.events.pop_front() else {
+            return Poll::Pending;
+        };
+        self.yielded += 1;
+        if self.yielded == self.notify_after {
+            self.notify.notify_one();
+        }
+        Poll::Ready(Some(Ok(event)))
+    }
+}
+
 #[test]
 fn build_subagent_headers_sets_other_subagent_label() {
     let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other(
@@ -166,51 +260,13 @@
 #[tokio::test]
 async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> {
     let temp = TempDir::new()?;
-    let writer = Arc::new(TraceWriter::create(
-        temp.path(),
-        "trace-1".to_string(),
-        "rollout-1".to_string(),
-        "thread-root".to_string(),
-    )?);
-    writer.append(RawTraceEventPayload::ThreadStarted {
-        thread_id: "thread-root".to_string(),
-        agent_path: "/root".to_string(),
-        metadata_payload: None,
-    })?;
-    writer.append(RawTraceEventPayload::CodexTurnStarted {
-        codex_turn_id: "turn-1".to_string(),
-        thread_id: "thread-root".to_string(),
-    })?;
-
-    let inference_trace = InferenceTraceContext::enabled(
-        writer,
-        "thread-root".to_string(),
-        "turn-1".to_string(),
-        "gpt-test".to_string(),
-        "test-provider".to_string(),
-    );
-    let attempt = inference_trace.start_attempt();
-    attempt.record_started(&json!({
-        "model": "gpt-test",
-        "input": [{
-            "type": "message",
-            "role": "user",
-            "content": [{"type": "input_text", "text": "hello"}]
-        }],
-    }));
+    let attempt = started_inference_attempt(&temp)?;

     // The provider has produced one complete output item, but no terminal
     // response.completed event. The harness has enough information to keep this
     // item in history, so the trace should preserve it when the stream is
     // abandoned.
-    let item = ResponseItem::Message {
-        id: Some("msg-1".to_string()),
-        role: "assistant".to_string(),
-        content: vec![ContentItem::OutputText {
-            text: "partial answer".to_string(),
-        }],
-        phase: None,
-    };
+    let item = output_message("msg-1", "partial answer");
     let api_stream = futures::stream::iter([Ok(ResponseEvent::OutputItemDone(item))])
         .chain(futures::stream::pending());
     let (mut stream, _) = super::map_response_stream(api_stream, test_session_telemetry(), attempt);
@@ -228,19 +284,51 @@

     // Cancellation is recorded by the mapper task after Drop wakes it, so the
     // replay may need a short wait before the terminal event appears on disk.
-    let mut rollout = replay_bundle(temp.path())?;
-    for _ in 0..50 {
-        let inference = rollout
-            .inference_calls
-            .values()
-            .next()
-            .expect("inference should be reduced");
-        if inference.execution.status == ExecutionStatus::Cancelled {
-            break;
-        }
-        tokio::time::sleep(Duration::from_millis(10)).await;
-        rollout = replay_bundle(temp.path())?;
-    }
+    let rollout = replay_until_cancelled(&temp).await?;
+    let inference = rollout
+        .inference_calls
+        .values()
+        .next()
+        .expect("inference should be reduced");
+
+    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
+    assert_eq!(inference.response_item_ids.len(), 1);
+    assert_eq!(rollout.raw_payloads.len(), 2);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn dropped_backpressured_response_stream_traces_cancelled_partial_output()
+-> anyhow::Result<()> {
+    let temp = TempDir::new()?;
+    let attempt = started_inference_attempt(&temp)?;
+    let backpressured_item_yielded = Arc::new(Notify::new());
+    let mut events = VecDeque::new();
+    for _ in 0..super::RESPONSE_STREAM_CHANNEL_CAPACITY {
+        events.push_back(ResponseEvent::Created);
+    }
+    events.push_back(ResponseEvent::OutputItemDone(output_message(
+        "msg-1",
+        "partial answer",
+    )));
+    let api_stream = NotifyAfterEventStream {
+        events,
+        yielded: 0,
+        notify_after: super::RESPONSE_STREAM_CHANNEL_CAPACITY + 1,
+        notify: Arc::clone(&backpressured_item_yielded),
+    };
+
+    let (stream, _) = super::map_response_stream(api_stream, test_session_telemetry(), attempt);
+
+    // Fill the mapper channel with non-terminal events, then yield one output
+    // item. The mapper has observed that item and is blocked trying to send it
+    // downstream, so dropping the consumer covers the send-failure path rather
+    // than the `consumer_dropped` select branch.
+    backpressured_item_yielded.notified().await;
+    drop(stream);
+
+    let rollout = replay_until_cancelled(&temp).await?;
     let inference = rollout
         .inference_calls
         .values()
         .next()
         .expect("inference should be reduced");

From 5687c27df9902c527e8d4eefa84fe1ee517b434c Mon Sep 17 00:00:00 2001
From: Albin Cassirer
Date: Mon, 27 Apr 2026 14:32:11 -0700
Subject: [PATCH 4/4] Respect the final inference status when there is one

---
 .../rollout-trace/src/reducer/inference.rs | 17 ++++--
 .../src/reducer/inference_tests.rs | 56 +++++++++++++++++
 2 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/codex-rs/rollout-trace/src/reducer/inference.rs b/codex-rs/rollout-trace/src/reducer/inference.rs
index 622fd911025b..c65b4236a26a 100644
--- a/codex-rs/rollout-trace/src/reducer/inference.rs
+++ b/codex-rs/rollout-trace/src/reducer/inference.rs
@@ -159,11 +159,18 @@ impl TraceReducer {
         let Some(inference) = self.rollout.inference_calls.get_mut(&inference_call_id) else {
             bail!("inference call {inference_call_id} disappeared during response reduction");
         };
-        inference.execution.ended_at_unix_ms = Some(wall_time_unix_ms);
-        inference.execution.ended_seq = Some(seq);
-        inference.execution.status = status;
-        inference.upstream_request_id = response_id;
-        inference.raw_response_payload_id = response_payload.map(|payload| payload.raw_payload_id);
+        // Turn-end cleanup can close a stream before the async mapper observes
+        // cancellation. Preserve that terminal status while still retaining any
+        // late partial response evidence from the mapper.
+        if inference.execution.status == ExecutionStatus::Running {
+            inference.execution.ended_at_unix_ms = Some(wall_time_unix_ms);
+            inference.execution.ended_seq = Some(seq);
+            inference.execution.status = status;
+            inference.upstream_request_id = response_id;
+        }
+        if let Some(response_payload) = response_payload {
+            inference.raw_response_payload_id = Some(response_payload.raw_payload_id);
+        }
         if let Some(response_item_ids) = response_item_ids {
             inference.response_item_ids = response_item_ids;
         }
diff --git a/codex-rs/rollout-trace/src/reducer/inference_tests.rs b/codex-rs/rollout-trace/src/reducer/inference_tests.rs
index ff5cda08e461..d5ba0ba9d02f 100644
--- a/codex-rs/rollout-trace/src/reducer/inference_tests.rs
+++ b/codex-rs/rollout-trace/src/reducer/inference_tests.rs
@@ -90,3 +90,59 @@

     Ok(())
 }
+
+#[test]
+fn late_cancelled_inference_preserves_turn_end_status() -> anyhow::Result<()> {
+    let temp = TempDir::new()?;
+    let writer = create_started_writer(&temp)?;
+    start_turn(&writer, "turn-1")?;
+
+    let request = writer.write_json_payload(
+        RawPayloadKind::InferenceRequest,
+        &json!({
+            "input": [message("user", "interrupt")]
+        }),
+    )?;
+    append_inference_start(&writer, "inference-1", "turn-1", request)?;
+    let turn_end = writer.append(RawTraceEventPayload::CodexTurnEnded {
+        codex_turn_id: "turn-1".to_string(),
+        status: ExecutionStatus::Failed,
+    })?;
+
+    let partial_response = writer.write_json_payload(
+        RawPayloadKind::InferenceResponse,
+        &json!({
+            "response_id": null,
+            "token_usage": null,
+            "output_items": [{
+                "type": "message",
+                "role": "assistant",
+                "content": [{"type": "output_text", "text": "late partial"}]
+            }]
+        }),
+    )?;
+    writer.append(RawTraceEventPayload::InferenceCancelled {
+        inference_call_id: "inference-1".to_string(),
+        reason: "stream mapper noticed cancellation after turn end".to_string(),
+        partial_response_payload: Some(partial_response.clone()),
+    })?;
+
+    let rollout = replay_bundle(temp.path())?;
+    let inference = &rollout.inference_calls["inference-1"];
+    assert_eq!(inference.execution.status, ExecutionStatus::Failed);
+    assert_eq!(inference.execution.ended_seq, Some(turn_end.seq));
+    assert_eq!(
+        inference.raw_response_payload_id,
+        Some(partial_response.raw_payload_id),
+    );
+    assert_eq!(inference.response_item_ids.len(), 1);
+    let response_item_id = &inference.response_item_ids[0];
+    assert_eq!(
+        rollout.conversation_items[response_item_id].body.parts,
+        vec![crate::model::ConversationPart::Text {
+            text: "late partial".to_string(),
+        }],
+    );
+
+    Ok(())
+}