Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 200 additions & 12 deletions codex-rs/exec-server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use futures::FutureExt;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio::sync::Semaphore;
use tokio::sync::mpsc;
use tokio::sync::watch;

Expand Down Expand Up @@ -192,27 +192,62 @@ pub struct ExecServerClient {
#[derive(Clone)]
pub(crate) struct LazyRemoteExecServerClient {
transport_params: ExecServerTransportParams,
client: Arc<OnceCell<ExecServerClient>>,
client: Arc<StdMutex<Option<ExecServerClient>>>,
connect_lock: Arc<Semaphore>,
}

impl LazyRemoteExecServerClient {
pub(crate) fn new(transport_params: ExecServerTransportParams) -> Self {
Self {
transport_params,
client: Arc::new(OnceCell::new()),
client: Arc::new(StdMutex::new(None)),
connect_lock: Arc::new(Semaphore::new(/*permits*/ 1)),
}
}

pub(crate) async fn get(&self) -> Result<ExecServerClient, ExecServerError> {
if let Some(client) = self.connected_client() {
return Ok(client);
}

let _connect_permit = self.connect_lock.acquire().await.map_err(|_| {
ExecServerError::Protocol("exec-server connect lock closed".to_string())
})?;
if let Some(client) = self.connected_client() {
return Ok(client);
}

let next_client = match self.cached_client() {
Some(client)
if matches!(
&self.transport_params,
ExecServerTransportParams::WebSocketUrl { .. }
) =>
{
ExecServerClient::connect_for_transport(self.transport_params.clone()).await?
}
Some(client) => return Ok(client),
None => ExecServerClient::connect_for_transport(self.transport_params.clone()).await?,
};

let mut cached_client = self
.client
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*cached_client = Some(next_client.clone());
Ok(next_client)
}

fn connected_client(&self) -> Option<ExecServerClient> {
self.cached_client()
.filter(|client| !client.is_disconnected())
}

fn cached_client(&self) -> Option<ExecServerClient> {
self.client
// TODO: Add reconnect/disconnect handling here instead of reusing
// the first successfully initialized connection forever.
.get_or_try_init(|| {
let transport_params = self.transport_params.clone();
async move { ExecServerClient::connect_for_transport(transport_params).await }
})
.await
.cloned()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone()
}
}

Expand Down Expand Up @@ -424,6 +459,10 @@ impl ExecServerClient {
.clone()
}

fn is_disconnected(&self) -> bool {
self.inner.disconnected.get().is_some()
}

pub(crate) async fn connect(
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
Expand Down Expand Up @@ -873,30 +912,38 @@ mod tests {
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use futures::SinkExt;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
#[cfg(unix)]
use std::path::Path;
#[cfg(unix)]
use std::process::Command;
use std::sync::Arc;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
#[cfg(unix)]
use tokio::time::sleep;
use tokio::time::timeout;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;

use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use super::LazyRemoteExecServerClient;
use crate::ProcessId;
#[cfg(not(windows))]
use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT;
#[cfg(not(windows))]
use crate::client_api::ExecServerTransportParams;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
Expand Down Expand Up @@ -937,6 +984,96 @@ mod tests {
.expect("json-rpc line should write");
}

async fn accept_websocket(listener: &TcpListener) -> WebSocketStream<TcpStream> {
let (stream, _) = listener.accept().await.expect("listener should accept");
accept_async(stream)
.await
.expect("websocket handshake should succeed")
}

async fn read_jsonrpc_websocket(websocket: &mut WebSocketStream<TcpStream>) -> JSONRPCMessage {
loop {
match timeout(Duration::from_secs(1), websocket.next())
.await
.expect("json-rpc websocket read should not time out")
.expect("websocket should stay open")
.expect("websocket frame should read")
{
Message::Text(text) => {
return serde_json::from_str(text.as_ref())
.expect("json-rpc text frame should parse");
}
Message::Binary(bytes) => {
return serde_json::from_slice(bytes.as_ref())
.expect("json-rpc binary frame should parse");
}
Message::Ping(_) | Message::Pong(_) => {}
other => panic!("expected json-rpc websocket frame, got {other:?}"),
}
}
}

async fn write_jsonrpc_websocket(
websocket: &mut WebSocketStream<TcpStream>,
message: JSONRPCMessage,
) {
let encoded = serde_json::to_string(&message).expect("json-rpc should serialize");
websocket
.send(Message::Text(encoded.into()))
.await
.expect("json-rpc websocket frame should write");
}

async fn complete_websocket_initialize(
websocket: &mut WebSocketStream<TcpStream>,
session_id: &str,
expected_resume_session_id: Option<&str>,
) {
let initialize = read_jsonrpc_websocket(websocket).await;
let request = match initialize {
JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request,
other => panic!("expected initialize request, got {other:?}"),
};
let params: crate::protocol::InitializeParams =
serde_json::from_value(request.params.expect("initialize params should exist"))
.expect("initialize params should deserialize");
assert_eq!(
params.resume_session_id.as_deref(),
expected_resume_session_id
);
write_jsonrpc_websocket(
websocket,
JSONRPCMessage::Response(JSONRPCResponse {
id: request.id,
result: serde_json::to_value(InitializeResponse {
session_id: session_id.to_string(),
})
.expect("initialize response should serialize"),
}),
)
.await;

let initialized = read_jsonrpc_websocket(websocket).await;
match initialized {
JSONRPCMessage::Notification(notification)
if notification.method == INITIALIZED_METHOD => {}
other => panic!("expected initialized notification, got {other:?}"),
}
}

async fn wait_for_disconnect(client: &ExecServerClient) {
timeout(Duration::from_secs(1), async {
loop {
if client.is_disconnected() {
return;
}
tokio::task::yield_now().await;
}
})
.await
.expect("client should observe disconnect");
}

#[cfg(not(windows))]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client() {
Expand Down Expand Up @@ -1354,6 +1491,57 @@ mod tests {
server.await.expect("server task should finish");
}

#[tokio::test]
async fn remote_websocket_client_replaces_disconnected_client_with_fresh_session() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let websocket_url = format!(
"ws://{}",
listener.local_addr().expect("listener should have address")
);
let server = tokio::spawn({
async move {
let mut first = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut first,
"session-1",
/*expected_resume_session_id*/ None,
)
.await;
first
.close(None)
.await
.expect("first websocket should close");

let mut second = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut second,
"session-2",
/*expected_resume_session_id*/ None,
)
.await;
}
});

let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl {
websocket_url,
connect_timeout: Duration::from_secs(1),
initialize_timeout: Duration::from_secs(1),
});
let first = client.get().await.expect("first client should connect");
wait_for_disconnect(&first).await;

let (replacement_a, replacement_b) = tokio::join!(client.get(), client.get());
let replacement_a = replacement_a.expect("first replacement should connect");
let replacement_b = replacement_b.expect("second replacement should reuse client");
assert_eq!(replacement_a.session_id().as_deref(), Some("session-2"));
assert_eq!(replacement_b.session_id().as_deref(), Some("session-2"));
assert!(Arc::ptr_eq(&replacement_a.inner, &replacement_b.inner));

server.await.expect("server task should finish");
}

#[tokio::test]
async fn wake_notifications_do_not_block_other_sessions() {
let (client_stdin, server_reader) = duplex(1 << 20);
Expand Down
Loading