Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions code-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ use reqwest::header::HeaderValue;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use tokio::sync::mpsc;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tokio_util::io::ReaderStream;
use tokio_stream::wrappers::ReceiverStream;
Expand Down Expand Up @@ -131,6 +132,7 @@ struct ResponsesWebsocketSession {
turn_state: Arc<OnceLock<String>>,
last_request: Option<ResponsesRequestSnapshot>,
last_response_id: Option<String>,
last_response_from_warmup: bool,
}

impl std::fmt::Debug for ResponsesWebsocketSession {
Expand All @@ -140,6 +142,7 @@ impl std::fmt::Debug for ResponsesWebsocketSession {
.field("has_turn_state", &self.turn_state.get().is_some())
.field("has_last_request", &self.last_request.is_some())
.field("has_last_response_id", &self.last_response_id.is_some())
.field("last_response_from_warmup", &self.last_response_from_warmup)
.finish()
}
}
Expand Down Expand Up @@ -1081,8 +1084,13 @@ impl ModelClient {
previous,
&current_snapshot,
) {
Some(input) => (Some(input), Some(response_id.clone())),
Some(input)
if !input.is_empty() || session.last_response_from_warmup =>
{
(Some(input), Some(response_id.clone()))
}
None => (None, None),
_ => (None, None),
}
}
_ => (None, None),
Expand Down Expand Up @@ -1125,8 +1133,11 @@ impl ModelClient {
warn!(
existing,
new = value,
"received unexpected x-codex-turn-state during websocket connect"
"received new x-codex-turn-state during websocket connect"
);
let refreshed = Arc::new(OnceLock::new());
let _ = refreshed.set(value.to_string());
session.turn_state = refreshed;
} else {
let _ = session.turn_state.set(value.to_string());
}
Expand Down Expand Up @@ -1224,6 +1235,7 @@ impl ModelClient {
session.connection = None;
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
let err = CodexErr::Stream(
format!("[ws] failed to send websocket request: {err}"),
None,
Expand All @@ -1247,8 +1259,10 @@ impl ModelClient {
mpsc::channel::<Result<Bytes>>(RESPONSES_WEBSOCKET_INGRESS_BUFFER);
let request_id_for_ws = request_id.clone();
let websocket_session = Arc::clone(&self.websocket_session);
let (reader_ready_tx, reader_ready_rx) = oneshot::channel();
tokio::spawn(async move {
let mut session = websocket_session.lock().await;
let _ = reader_ready_tx.send(());
let Some(ws_stream) = session.connection.as_mut() else {
let _ = tx_bytes
.send(Err(CodexErr::Stream(
Expand All @@ -1265,6 +1279,7 @@ impl ModelClient {
session.connection = None;
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
break;
};
match next {
Expand All @@ -1275,6 +1290,7 @@ impl ModelClient {
session.connection = None;
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
let _ = tx_bytes.send(Err(error)).await;
break;
}
Expand All @@ -1289,10 +1305,12 @@ impl ModelClient {
Some(response_id) if !response_id.is_empty() => {
session.last_request = Some(current_snapshot);
session.last_response_id = Some(response_id);
session.last_response_from_warmup = warmup;
}
_ => {
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
}
}
break;
Expand All @@ -1307,12 +1325,16 @@ impl ModelClient {
Ok(Message::Pong(_)) => {}
Ok(Message::Close(_)) => {
session.connection = None;
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
break;
}
Ok(Message::Binary(_)) => {
session.connection = None;
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
let _ = tx_bytes
.send(Err(CodexErr::Stream(
"[ws] unexpected binary websocket event".to_string(),
Expand All @@ -1327,6 +1349,7 @@ impl ModelClient {
session.connection = None;
session.last_request = None;
session.last_response_id = None;
session.last_response_from_warmup = false;
let _ = tx_bytes
.send(Err(CodexErr::Stream(
format!("[ws] websocket error: {err}"),
Expand All @@ -1339,6 +1362,7 @@ impl ModelClient {
}
}
});
let _ = reader_ready_rx.await;

let stream = ReceiverStream::new(rx_bytes);
let debug_logger = Arc::clone(&self.debug_logger);
Expand Down