diff --git a/code-rs/core/src/client.rs b/code-rs/core/src/client.rs index 90fd36f3a9c1..e561d3c1e0f5 100644 --- a/code-rs/core/src/client.rs +++ b/code-rs/core/src/client.rs @@ -23,8 +23,9 @@ use reqwest::header::HeaderValue; use serde::Deserialize; use serde::Serialize; use serde_json::Value; -use tokio::sync::mpsc; use tokio::sync::Mutex as TokioMutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::time::timeout; use tokio_util::io::ReaderStream; use tokio_stream::wrappers::ReceiverStream; @@ -131,6 +132,7 @@ struct ResponsesWebsocketSession { turn_state: Arc>, last_request: Option, last_response_id: Option, + last_response_from_warmup: bool, } impl std::fmt::Debug for ResponsesWebsocketSession { @@ -140,6 +142,7 @@ impl std::fmt::Debug for ResponsesWebsocketSession { .field("has_turn_state", &self.turn_state.get().is_some()) .field("has_last_request", &self.last_request.is_some()) .field("has_last_response_id", &self.last_response_id.is_some()) + .field("last_response_from_warmup", &self.last_response_from_warmup) .finish() } } @@ -1081,8 +1084,13 @@ impl ModelClient { previous, ¤t_snapshot, ) { - Some(input) => (Some(input), Some(response_id.clone())), + Some(input) + if !input.is_empty() || session.last_response_from_warmup => + { + (Some(input), Some(response_id.clone())) + } None => (None, None), + _ => (None, None), } } _ => (None, None), @@ -1125,8 +1133,11 @@ impl ModelClient { warn!( existing, new = value, - "received unexpected x-codex-turn-state during websocket connect" + "received new x-codex-turn-state during websocket connect" ); + let refreshed = Arc::new(OnceLock::new()); + let _ = refreshed.set(value.to_string()); + session.turn_state = refreshed; } else { let _ = session.turn_state.set(value.to_string()); } @@ -1224,6 +1235,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let err = CodexErr::Stream( format!("[ws] failed to send websocket request: {err}"), None, @@ -1247,8 +1259,10 @@ impl ModelClient { mpsc::channel::>(RESPONSES_WEBSOCKET_INGRESS_BUFFER); let request_id_for_ws = request_id.clone(); let websocket_session = Arc::clone(&self.websocket_session); + let (reader_ready_tx, reader_ready_rx) = oneshot::channel(); tokio::spawn(async move { let mut session = websocket_session.lock().await; + let _ = reader_ready_tx.send(()); let Some(ws_stream) = session.connection.as_mut() else { let _ = tx_bytes .send(Err(CodexErr::Stream( @@ -1265,6 +1279,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; break; }; match next { @@ -1275,6 +1290,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let _ = tx_bytes.send(Err(error)).await; break; } @@ -1289,10 +1305,12 @@ impl ModelClient { Some(response_id) if !response_id.is_empty() => { session.last_request = Some(current_snapshot); session.last_response_id = Some(response_id); + session.last_response_from_warmup = warmup; } _ => { session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; } } break; @@ -1307,12 +1325,16 @@ impl ModelClient { Ok(Message::Pong(_)) => {} Ok(Message::Close(_)) => { session.connection = None; + session.last_request = None; + session.last_response_id = None; + session.last_response_from_warmup = false; break; } Ok(Message::Binary(_)) => { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let _ = tx_bytes .send(Err(CodexErr::Stream( "[ws] unexpected binary websocket event".to_string(), @@ -1327,6 +1349,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let _ = tx_bytes .send(Err(CodexErr::Stream( format!("[ws] websocket error: {err}"), @@ -1339,6 +1362,7 @@ impl ModelClient { } } }); + let _ = reader_ready_rx.await; let stream = ReceiverStream::new(rx_bytes); let debug_logger = Arc::clone(&self.debug_logger);