Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use tracing::trace;
use tracing::warn;
Expand Down Expand Up @@ -1232,7 +1233,8 @@ impl ModelClientSession {
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
inference_trace_attempt.record_failed(&unauthorized_transport);
inference_trace_attempt
.record_failed(&unauthorized_transport, /*output_items*/ &[]);
pending_retry = PendingUnauthorizedRetry::from_recovery(
handle_unauthorized(
unauthorized_transport,
Expand All @@ -1245,7 +1247,7 @@ impl ModelClientSession {
}
Err(err) => {
let err = map_api_error(err);
inference_trace_attempt.record_failed(&err);
inference_trace_attempt.record_failed(&err, /*output_items*/ &[]);
return Err(err);
}
}
Expand Down Expand Up @@ -1372,7 +1374,7 @@ impl ModelClientSession {
.await
.map_err(|err| {
let err = map_api_error(err);
inference_trace_attempt.record_failed(&err);
inference_trace_attempt.record_failed(&err, /*output_items*/ &[]);
err
})?;
let (stream, last_request_rx) = map_response_stream(
Expand Down Expand Up @@ -1621,6 +1623,9 @@ fn parent_thread_id_header_value(session_source: &SessionSource) -> Option<Strin
}
}

/// Bounded capacity of the mpsc channel between the stream-mapper task and the
/// `ResponseStream` consumer; once this many events are queued, `send` blocks
/// (the backpressure test in `client_tests.rs` fills exactly this many events).
const RESPONSE_STREAM_CHANNEL_CAPACITY: usize = 1600;
/// Cancellation reason recorded on the inference trace when the consumer drops
/// the `ResponseStream` before the provider emits its own terminal event.
const STREAM_DROPPED_REASON: &str = "response stream dropped before provider terminal event";

fn map_response_stream<S>(
api_stream: S,
session_telemetry: SessionTelemetry,
Expand All @@ -1632,15 +1637,28 @@ where
+ Send
+ 'static,
{
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let (tx_event, rx_event) =
mpsc::channel::<Result<ResponseEvent>>(RESPONSE_STREAM_CHANNEL_CAPACITY);
let (tx_last_response, rx_last_response) = oneshot::channel::<LastResponse>();
let consumer_dropped = CancellationToken::new();
let consumer_dropped_for_stream = consumer_dropped.clone();

tokio::spawn(async move {
let mut logged_error = false;
let mut tx_last_response = Some(tx_last_response);
let mut items_added: Vec<ResponseItem> = Vec::new();
let mut api_stream = api_stream;
while let Some(event) = api_stream.next().await {
loop {
let event = tokio::select! {
_ = consumer_dropped.cancelled() => {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only records cancellation while the mapper is waiting on api_stream.next(). If the consumer is dropped while we’re blocked in one of the tx_event.send(...).await paths, we lose it. So an interrupted stream can lose the cancellation record we're trying to capture here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, good catch. The token only covers the api_stream.next() wait. I added explicit cancellation recording on the tx_event.send(...).await error paths too, so if the receiver is dropped while a send is blocked we still close the inference as cancelled and keep the items observed so far.

inference_trace_attempt.record_cancelled(STREAM_DROPPED_REASON, &items_added);
return;
}
event = api_stream.next() => event,
};
let Some(event) = event else {
break;
};
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
items_added.push(item.clone());
Expand All @@ -1649,6 +1667,8 @@ where
.await
.is_err()
{
inference_trace_attempt
.record_cancelled(STREAM_DROPPED_REASON, &items_added);
return;
}
}
Expand Down Expand Up @@ -1691,12 +1711,14 @@ where
}
Ok(event) => {
if tx_event.send(Ok(event)).await.is_err() {
inference_trace_attempt
.record_cancelled(STREAM_DROPPED_REASON, &items_added);
return;
}
}
Err(err) => {
let mapped = map_api_error(err);
inference_trace_attempt.record_failed(&mapped);
inference_trace_attempt.record_failed(&mapped, &items_added);
if !logged_error {
session_telemetry.see_event_completed_failed(&mapped);
logged_error = true;
Expand All @@ -1707,9 +1729,17 @@ where
}
}
}
inference_trace_attempt
.record_failed("stream closed before response.completed", &items_added);
});

(ResponseStream { rx_event }, rx_last_response)
(
ResponseStream {
rx_event,
consumer_dropped: consumer_dropped_for_stream,
},
rx_last_response,
)
}

/// Handles a 401 response by optionally refreshing ChatGPT tokens once.
Expand Down
10 changes: 10 additions & 0 deletions codex-rs/core/src/client_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
Expand Down Expand Up @@ -175,6 +176,9 @@ fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {

/// Consumer-facing stream of mapped [`ResponseEvent`]s produced by the
/// mapper task spawned in `map_response_stream`.
pub struct ResponseStream {
    // Receiving half of the bounded channel the mapper task sends events into.
    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
    /// Signals the mapper task that the consumer stopped polling before the
    /// provider stream reached its own terminal event.
    pub(crate) consumer_dropped: CancellationToken,
}

impl Stream for ResponseStream {
Expand All @@ -185,6 +189,12 @@ impl Stream for ResponseStream {
}
}

impl Drop for ResponseStream {
    fn drop(&mut self) {
        // Cancelling the token tells the mapper task that the consumer has
        // stopped polling, so it can stop and record the in-flight inference
        // attempt as cancelled instead of blocking on a channel nobody reads.
        self.consumer_dropped.cancel();
    }
}

#[cfg(test)]
#[path = "client_common_tests.rs"]
mod tests;
191 changes: 191 additions & 0 deletions codex-rs/core/src/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,37 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER;
use super::X_CODEX_TURN_METADATA_HEADER;
use super::X_CODEX_WINDOW_ID_HEADER;
use super::X_OPENAI_SUBAGENT_HEADER;
use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
use codex_protocol::ThreadId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_rollout_trace::ExecutionStatus;
use codex_rollout_trace::InferenceTraceAttempt;
use codex_rollout_trace::InferenceTraceContext;
use codex_rollout_trace::RawTraceEventPayload;
use codex_rollout_trace::RolloutTrace;
use codex_rollout_trace::TraceWriter;
use codex_rollout_trace::replay_bundle;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use tempfile::TempDir;
use tokio::sync::Notify;

fn test_model_client(session_source: SessionSource) -> ModelClient {
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
Expand Down Expand Up @@ -79,6 +99,92 @@ fn test_session_telemetry() -> SessionTelemetry {
)
}

/// Builds an `InferenceTraceAttempt` backed by a real on-disk trace bundle in
/// `temp`, with thread/turn start events already appended and the attempt's
/// request payload recorded, so tests can replay the bundle afterwards.
fn started_inference_attempt(temp: &TempDir) -> anyhow::Result<InferenceTraceAttempt> {
    let trace_writer = Arc::new(TraceWriter::create(
        temp.path(),
        "trace-1".to_string(),
        "rollout-1".to_string(),
        "thread-root".to_string(),
    )?);
    // The bundle needs thread and turn start events before an inference
    // attempt can be reduced from it.
    trace_writer.append(RawTraceEventPayload::ThreadStarted {
        thread_id: "thread-root".to_string(),
        agent_path: "/root".to_string(),
        metadata_payload: None,
    })?;
    trace_writer.append(RawTraceEventPayload::CodexTurnStarted {
        codex_turn_id: "turn-1".to_string(),
        thread_id: "thread-root".to_string(),
    })?;

    let trace_context = InferenceTraceContext::enabled(
        trace_writer,
        "thread-root".to_string(),
        "turn-1".to_string(),
        "gpt-test".to_string(),
        "test-provider".to_string(),
    );
    let attempt = trace_context.start_attempt();
    // Minimal request payload: one user message, matching what a real call
    // would record at attempt start.
    let request_payload = json!({
        "model": "gpt-test",
        "input": [{
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": "hello"}]
        }],
    });
    attempt.record_started(&request_payload);
    Ok(attempt)
}

/// Builds an assistant `ResponseItem::Message` with the given id and a single
/// output-text content chunk.
fn output_message(id: &str, text: &str) -> ResponseItem {
    let content = vec![ContentItem::OutputText {
        text: text.to_string(),
    }];
    ResponseItem::Message {
        role: "assistant".to_string(),
        id: Some(id.to_string()),
        content,
        phase: None,
    }
}

/// Re-replays the trace bundle in `temp` until its single inference call is
/// recorded as `Cancelled`, polling every 10ms for up to 50 retries. Returns
/// the last replayed bundle either way, so the caller's assertions produce a
/// useful failure message if cancellation never lands on disk.
async fn replay_until_cancelled(temp: &TempDir) -> anyhow::Result<RolloutTrace> {
    let mut retries_left = 50;
    loop {
        let rollout = replay_bundle(temp.path())?;
        let cancelled = rollout
            .inference_calls
            .values()
            .next()
            .expect("inference should be reduced")
            .execution
            .status
            == ExecutionStatus::Cancelled;
        if cancelled || retries_left == 0 {
            return Ok(rollout);
        }
        retries_left -= 1;
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
}

/// Test stream that yields a fixed queue of events, fires `notify` when the
/// `notify_after`-th event is yielded, and then stays pending forever.
struct NotifyAfterEventStream {
    // Events still to be yielded, front first.
    events: VecDeque<ResponseEvent>,
    // Number of events yielded so far.
    yielded: usize,
    // 1-based yield ordinal at which `notify` fires.
    notify_after: usize,
    notify: Arc<Notify>,
}

impl futures::Stream for NotifyAfterEventStream {
    type Item = std::result::Result<ResponseEvent, ApiError>;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // NOTE(test-only): once the queue is drained this returns `Pending`
        // WITHOUT registering the waker from `_cx`, so this stream will never
        // wake its task again. That is acceptable here because the mapper task
        // also selects on its cancellation token, which wakes it when the
        // consumer is dropped — do not reuse this stream outside such a setup.
        let Some(event) = self.events.pop_front() else {
            return Poll::Pending;
        };
        self.yielded += 1;
        if self.yielded == self.notify_after {
            self.notify.notify_one();
        }
        Poll::Ready(Some(Ok(event)))
    }
}

#[test]
fn build_subagent_headers_sets_other_subagent_label() {
let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other(
Expand Down Expand Up @@ -151,6 +257,91 @@ async fn summarize_memories_returns_empty_for_empty_input() {
assert_eq!(output.len(), 0);
}

#[tokio::test]
async fn dropped_response_stream_traces_cancelled_partial_output() -> anyhow::Result<()> {
    let temp = TempDir::new()?;
    let attempt = started_inference_attempt(&temp)?;

    // One complete output item, then silence: the provider never sends its
    // terminal response.completed event. The harness has enough information
    // to keep that item in history, so an abandoned stream must preserve it
    // in the trace.
    let produced = [Ok(ResponseEvent::OutputItemDone(output_message(
        "msg-1",
        "partial answer",
    )))];
    let api_stream = futures::stream::iter(produced).chain(futures::stream::pending());
    let (mut mapped, _) =
        super::map_response_stream(api_stream, test_session_telemetry(), attempt);

    let first = mapped
        .next()
        .await
        .expect("mapped stream should yield output item")?;
    assert!(matches!(first, ResponseEvent::OutputItemDone(_)));

    // Turn interruption/preemption is modelled by dropping the consumer; the
    // mapper task observes that drop asynchronously and records cancellation
    // using the output items it has already seen.
    drop(mapped);

    // The terminal trace event is written only after Drop wakes the mapper,
    // so the replay may need a short wait before it appears on disk.
    let rollout = replay_until_cancelled(&temp).await?;
    let inference = rollout
        .inference_calls
        .values()
        .next()
        .expect("inference should be reduced");

    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
    assert_eq!(inference.response_item_ids.len(), 1);
    assert_eq!(rollout.raw_payloads.len(), 2);

    Ok(())
}

#[tokio::test]
async fn dropped_backpressured_response_stream_traces_cancelled_partial_output()
-> anyhow::Result<()> {
    let temp = TempDir::new()?;
    let attempt = started_inference_attempt(&temp)?;
    let backpressured_item_yielded = Arc::new(Notify::new());
    // Exactly enough filler events to saturate the mapper's bounded channel,
    // so the item that follows them cannot be forwarded downstream.
    let mut events = VecDeque::new();
    for _ in 0..super::RESPONSE_STREAM_CHANNEL_CAPACITY {
        events.push_back(ResponseEvent::Created);
    }
    events.push_back(ResponseEvent::OutputItemDone(output_message(
        "msg-1",
        "partial answer",
    )));
    // Fires `notify` on the (capacity + 1)-th yield, i.e. when the output
    // item has been handed to the mapper. `Notify` stores the permit, so the
    // signal is not lost even if `notified()` is awaited afterwards.
    let api_stream = NotifyAfterEventStream {
        events,
        yielded: 0,
        notify_after: super::RESPONSE_STREAM_CHANNEL_CAPACITY + 1,
        notify: Arc::clone(&backpressured_item_yielded),
    };

    let (stream, _) = super::map_response_stream(api_stream, test_session_telemetry(), attempt);

    // Fill the mapper channel with non-terminal events, then yield one output
    // item. The mapper has observed that item and is blocked trying to send it
    // downstream, so dropping the consumer covers the send-failure path rather
    // than the `consumer_dropped` select branch.
    backpressured_item_yielded.notified().await;
    drop(stream);

    // Cancellation is recorded asynchronously; poll the on-disk bundle until
    // the terminal event lands.
    let rollout = replay_until_cancelled(&temp).await?;
    let inference = rollout
        .inference_calls
        .values()
        .next()
        .expect("inference should be reduced");

    assert_eq!(inference.execution.status, ExecutionStatus::Cancelled);
    assert_eq!(inference.response_item_ids.len(), 1);
    assert_eq!(rollout.raw_payloads.len(), 2);

    Ok(())
}

#[test]
fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() {
let auth_context = AuthRequestTelemetryContext::new(
Expand Down
Loading
Loading