diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 6bc53a49fbc8..1e2e3a302bd1 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1868,6 +1868,7 @@ dependencies = [ "codex-state", "codex-thread-store", "codex-tools", + "codex-uds", "codex-utils-absolute-path", "codex-utils-cargo-bin", "codex-utils-cli", diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 339bc20f10f0..de1fd74aa69c 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -59,6 +59,7 @@ codex-sandboxing = { workspace = true } codex-state = { workspace = true } codex-thread-store = { workspace = true } codex-tools = { workspace = true } +codex-uds = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-json-to-toml = { workspace = true } codex-utils-rustls-provider = { workspace = true } diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 15b7f9a5be87..a280162615e7 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -25,6 +25,7 @@ Supported transports: - stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL) - websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**) +- unix socket (`--listen unix://` or `--listen unix://PATH`): websocket frames over `$CODEX_HOME/app-server-control/app-server-control.sock` or a custom socket path without HTTP upgrade - off (`--listen off`): do not expose a local transport When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes: @@ -35,6 +36,11 @@ When running with `--listen ws://IP:PORT`, the same listener also serves basic H Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads. +The unix socket transport is intended for local app-server control-plane clients. `codex app-server proxy` +opens exactly one raw stream connection to `$CODEX_HOME/app-server-control/app-server-control.sock` +by default, or to `--sock PATH` when provided, and proxies bytes between that socket and stdin/stdout. +The socket uses websocket framing directly over the Unix socket, without an HTTP upgrade handshake. + Security note: - Loopback websocket listeners (`ws://127.0.0.1:PORT`) remain appropriate for localhost and SSH port-forwarding workflows. diff --git a/codex-rs/app-server/src/app_server_tracing.rs b/codex-rs/app-server/src/app_server_tracing.rs index 2118e7730051..6e8133740f94 100644 --- a/codex-rs/app-server/src/app_server_tracing.rs +++ b/codex-rs/app-server/src/app_server_tracing.rs @@ -23,7 +23,7 @@ use tracing::info_span; pub(crate) fn request_span( request: &JSONRPCRequest, - transport: AppServerTransport, + transport: &AppServerTransport, connection_id: ConnectionId, session: &ConnectionSessionState, ) -> Span { @@ -82,9 +82,10 @@ pub(crate) fn typed_request_span( span } -fn transport_name(transport: AppServerTransport) -> &'static str { +fn transport_name(transport: &AppServerTransport) -> &'static str { match transport { AppServerTransport::Stdio => "stdio", + AppServerTransport::UnixSocket { .. } => "unix_socket", AppServerTransport::WebSocket { .. } => "websocket", AppServerTransport::Off => "off", } diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index a2f35305ae75..1c4f094335c3 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -31,6 +31,7 @@ use crate::transport::OutboundConnectionState; use crate::transport::TransportEvent; use crate::transport::auth::policy_from_settings; use crate::transport::route_outgoing_envelope; +use crate::transport::start_control_socket_acceptor; use crate::transport::start_remote_control; use crate::transport::start_stdio_connection; use crate::transport::start_websocket_acceptor; @@ -93,6 +94,7 @@ mod transport; pub use crate::error_code::INPUT_TOO_LARGE_ERROR_CODE; pub use crate::error_code::INVALID_PARAMS_ERROR_CODE; pub use crate::transport::AppServerTransport; +pub use crate::transport::app_server_control_socket_path; pub use crate::transport::auth::AppServerWebsocketAuthArgs; pub use crate::transport::auth::AppServerWebsocketAuthSettings; pub use crate::transport::auth::WebsocketAuthCliMode; @@ -542,7 +544,7 @@ pub async fn run_main_with_transport( let graceful_signal_restart_enabled = !single_client_mode; let mut app_server_client_name_rx = None; - match transport { + match &transport { AppServerTransport::Stdio => { let (stdio_client_name_tx, stdio_client_name_rx) = oneshot::channel::(); app_server_client_name_rx = Some(stdio_client_name_rx); @@ -553,9 +555,18 @@ pub async fn run_main_with_transport( ) .await?; } + AppServerTransport::UnixSocket { socket_path } => { + let accept_handle = start_control_socket_acceptor( + socket_path.clone(), + transport_event_tx.clone(), + transport_shutdown_token.clone(), + ) + .await?; + transport_accept_handles.push(accept_handle); + } AppServerTransport::WebSocket { bind_address } => { let accept_handle = start_websocket_acceptor( - bind_address, + *bind_address, transport_event_tx.clone(), transport_shutdown_token.clone(), policy_from_settings(&auth)?, @@ -660,7 +671,7 @@ pub async fn run_main_with_transport( config_warnings, session_source, auth_manager, - rpc_transport: analytics_rpc_transport(transport), + rpc_transport: analytics_rpc_transport(&transport), remote_control_handle: Some(remote_control_handle), })); let mut thread_created_rx = processor.thread_created_receiver(); @@ -772,7 +783,7 @@ pub async fn run_main_with_transport( .process_request( connection_id, request, - transport, + &transport, Arc::clone(&connection_state.session), ) .await; @@ -892,12 +903,12 @@ pub async fn run_main_with_transport( Ok(()) } -fn analytics_rpc_transport(transport: AppServerTransport) -> AppServerRpcTransport { +fn analytics_rpc_transport(transport: &AppServerTransport) -> AppServerRpcTransport { match transport { AppServerTransport::Stdio => AppServerRpcTransport::Stdio, - AppServerTransport::WebSocket { .. } | AppServerTransport::Off => { - AppServerRpcTransport::Websocket - } + AppServerTransport::UnixSocket { .. } + | AppServerTransport::WebSocket { .. } + | AppServerTransport::Off => AppServerRpcTransport::Websocket, } } diff --git a/codex-rs/app-server/src/main.rs b/codex-rs/app-server/src/main.rs index 069227070e1f..e3791609336e 100644 --- a/codex-rs/app-server/src/main.rs +++ b/codex-rs/app-server/src/main.rs @@ -17,7 +17,7 @@ const DISABLE_MANAGED_CONFIG_ENV_VAR: &str = "CODEX_APP_SERVER_DISABLE_MANAGED_C #[derive(Debug, Parser)] struct AppServerArgs { /// Transport endpoint URL. Supported values: `stdio://` (default), - /// `ws://IP:PORT`, `off`. + /// `unix://`, `unix://PATH`, `ws://IP:PORT`, `off`. #[arg( long = "listen", value_name = "URL", diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 57fa6e21e0c7..48e2aa6a1459 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -362,7 +362,7 @@ impl MessageProcessor { self: &Arc, connection_id: ConnectionId, request: JSONRPCRequest, - transport: AppServerTransport, + transport: &AppServerTransport, session: Arc, ) { let request_method = request.method.as_str(); diff --git a/codex-rs/app-server/src/message_processor/tracing_tests.rs b/codex-rs/app-server/src/message_processor/tracing_tests.rs index e0a1ed4bc436..666ed013be32 100644 --- a/codex-rs/app-server/src/message_processor/tracing_tests.rs +++ b/codex-rs/app-server/src/message_processor/tracing_tests.rs @@ -187,7 +187,7 @@ impl TracingHarness { .process_request( TEST_CONNECTION_ID, request, - AppServerTransport::Stdio, + &AppServerTransport::Stdio, Arc::clone(&self.session), ) .await; @@ -210,7 +210,7 @@ impl TracingHarness { .process_request( TEST_CONNECTION_ID, request, - AppServerTransport::Stdio, + &AppServerTransport::Stdio, Arc::clone(&self.session), ) .await; diff --git a/codex-rs/app-server/src/transport/mod.rs b/codex-rs/app-server/src/transport/mod.rs index cb3510da7726..22e7a80a5deb 100644 --- a/codex-rs/app-server/src/transport/mod.rs +++ b/codex-rs/app-server/src/transport/mod.rs @@ -10,9 +10,12 @@ use crate::outgoing_message::QueuedOutgoingMessage; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::ServerRequest; +use codex_core::config::find_codex_home; +use codex_utils_absolute_path::AbsolutePathBuf; use std::collections::HashMap; use std::collections::HashSet; use std::net::SocketAddr; +use std::path::Path; use std::str::FromStr; use std::sync::Arc; use std::sync::RwLock; @@ -31,16 +34,32 @@ pub(crate) const CHANNEL_CAPACITY: usize = 128; mod remote_control; mod stdio; +mod unix_socket; +#[cfg(test)] +mod unix_socket_tests; mod websocket; pub(crate) use remote_control::RemoteControlHandle; pub(crate) use remote_control::start_remote_control; pub(crate) use stdio::start_stdio_connection; +pub(crate) use unix_socket::start_control_socket_acceptor; pub(crate) use websocket::start_websocket_acceptor; -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +const APP_SERVER_CONTROL_SOCKET_DIR_NAME: &str = "app-server-control"; +const APP_SERVER_CONTROL_SOCKET_FILE_NAME: &str = "app-server-control.sock"; + +pub fn app_server_control_socket_path(codex_home: &Path) -> std::io::Result { + AbsolutePathBuf::from_absolute_path( + codex_home + .join(APP_SERVER_CONTROL_SOCKET_DIR_NAME) + .join(APP_SERVER_CONTROL_SOCKET_FILE_NAME), + ) +} + +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AppServerTransport { Stdio, + UnixSocket { socket_path: AbsolutePathBuf }, WebSocket { bind_address: SocketAddr }, Off, } @@ -48,6 +67,7 @@ pub enum AppServerTransport { #[derive(Debug, Clone, Eq, PartialEq)] pub enum AppServerTransportParseError { UnsupportedListenUrl(String), + InvalidUnixSocketPath { listen_url: String, message: String }, InvalidWebSocketListenUrl(String), } @@ -56,7 +76,14 @@ impl std::fmt::Display for AppServerTransportParseError { match self { AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( f, - "unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`" + "unsupported --listen URL `{listen_url}`; expected `stdio://`, `unix://`, `unix://PATH`, `ws://IP:PORT`, or `off`" + ), + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url, + message, + } => write!( + f, + "invalid unix socket --listen URL `{listen_url}`; failed to resolve socket path: {message}" ), AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( f, @@ -76,6 +103,31 @@ impl AppServerTransport { return Ok(Self::Stdio); } + if let Some(raw_socket_path) = listen_url.strip_prefix("unix://") { + let socket_path = if raw_socket_path.is_empty() { + let codex_home = find_codex_home().map_err(|err| { + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url: listen_url.to_string(), + message: format!("failed to resolve CODEX_HOME: {err}"), + } + })?; + app_server_control_socket_path(&codex_home).map_err(|err| { + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url: listen_url.to_string(), + message: err.to_string(), + } + })? + } else { + AbsolutePathBuf::relative_to_current_dir(raw_socket_path).map_err(|err| { + AppServerTransportParseError::InvalidUnixSocketPath { + listen_url: listen_url.to_string(), + message: err.to_string(), + } + })? + }; + return Ok(Self::UnixSocket { socket_path }); + } + if listen_url == "off" { return Ok(Self::Off); } diff --git a/codex-rs/app-server/src/transport/unix_socket.rs b/codex-rs/app-server/src/transport/unix_socket.rs new file mode 100644 index 000000000000..3075676dacbd --- /dev/null +++ b/codex-rs/app-server/src/transport/unix_socket.rs @@ -0,0 +1,161 @@ +use std::io::ErrorKind; +use std::io::Result as IoResult; +use std::path::Path; + +use super::TransportEvent; +use crate::transport::websocket::run_websocket_connection; +use codex_uds::UnixListener; +use codex_uds::UnixStream; +use codex_utils_absolute_path::AbsolutePathBuf; +use futures::StreamExt; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio::time::Duration; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::protocol::Role; +use tokio_util::sync::CancellationToken; +use tracing::error; +use tracing::info; +use tracing::warn; + +#[cfg(unix)] +const CONTROL_SOCKET_MODE: u32 = 0o600; + +pub(crate) async fn start_control_socket_acceptor( + socket_path: AbsolutePathBuf, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, +) -> IoResult> { + prepare_control_socket_path(socket_path.as_path()).await?; + let listener = UnixListener::bind(socket_path.as_path()).await?; + let socket_guard = ControlSocketFileGuard { socket_path }; + set_control_socket_permissions(socket_guard.socket_path.as_path()).await?; + info!( + socket_path = %socket_guard.socket_path.display(), + "app-server control socket listening" + ); + + Ok(tokio::spawn(run_control_socket_acceptor( + listener, + transport_event_tx, + shutdown_token, + socket_guard, + ))) +} + +async fn run_control_socket_acceptor( + mut listener: UnixListener, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, + socket_guard: ControlSocketFileGuard, +) { + let _socket_guard = socket_guard; + loop { + let stream = tokio::select! { + _ = shutdown_token.cancelled() => { + break; + } + result = listener.accept() => { + match result { + Ok(stream) => stream, + Err(err) => { + if matches!( + err.kind(), + ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset | ErrorKind::Interrupted + ) { + warn!("recoverable control socket accept error: {err}"); + continue; + } + error!("control socket accept error: {err}"); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + } + } + }; + + let transport_event_tx = transport_event_tx.clone(); + tokio::spawn(async move { + let websocket_stream = + WebSocketStream::from_raw_socket(stream, Role::Server, None).await; + let (websocket_writer, websocket_reader) = websocket_stream.split(); + run_websocket_connection(websocket_writer, websocket_reader, transport_event_tx).await; + }); + } + info!("control socket acceptor shutting down"); +} + +async fn prepare_control_socket_path(socket_path: &Path) -> IoResult<()> { + if let Some(parent) = socket_path.parent() { + codex_uds::prepare_private_socket_directory(parent).await?; + } + + match UnixStream::connect(socket_path).await { + Ok(_stream) => { + return Err(std::io::Error::new( + ErrorKind::AddrInUse, + format!( + "app-server control socket is already in use at {}", + socket_path.display() + ), + )); + } + Err(err) if err.kind() == ErrorKind::NotFound => return Ok(()), + Err(err) if err.kind() == ErrorKind::ConnectionRefused => {} + Err(err) => { + if !socket_path.exists() { + return Ok(()); + } + return Err(err); + } + } + + if !socket_path.try_exists()? { + return Ok(()); + } + + if !codex_uds::is_stale_socket_path(socket_path).await? { + return Err(std::io::Error::new( + ErrorKind::AlreadyExists, + format!( + "app-server control socket path exists and is not a socket: {}", + socket_path.display() + ), + )); + } + tokio::fs::remove_file(socket_path).await +} + +#[cfg(unix)] +async fn set_control_socket_permissions(socket_path: &Path) -> IoResult<()> { + use std::os::unix::fs::PermissionsExt; + + tokio::fs::set_permissions( + socket_path, + std::fs::Permissions::from_mode(CONTROL_SOCKET_MODE), + ) + .await +} + +#[cfg(not(unix))] +async fn set_control_socket_permissions(_socket_path: &Path) -> IoResult<()> { + Ok(()) +} + +struct ControlSocketFileGuard { + socket_path: AbsolutePathBuf, +} + +impl Drop for ControlSocketFileGuard { + fn drop(&mut self) { + if let Err(err) = std::fs::remove_file(self.socket_path.as_path()) + && err.kind() != ErrorKind::NotFound + { + warn!( + socket_path = %self.socket_path.display(), + %err, + "failed to remove app-server control socket file" + ); + } + } +} diff --git a/codex-rs/app-server/src/transport/unix_socket_tests.rs b/codex-rs/app-server/src/transport/unix_socket_tests.rs new file mode 100644 index 000000000000..c2f7a7d3530d --- /dev/null +++ b/codex-rs/app-server/src/transport/unix_socket_tests.rs @@ -0,0 +1,199 @@ +use super::AppServerTransport; +use super::CHANNEL_CAPACITY; +use super::TransportEvent; +use super::app_server_control_socket_path; +use super::start_control_socket_acceptor; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_core::config::find_codex_home; +use codex_uds::UnixStream; +use codex_utils_absolute_path::AbsolutePathBuf; +use futures::SinkExt; +use futures::StreamExt; +use pretty_assertions::assert_eq; +use std::io::Result as IoResult; +use std::path::Path; +use tokio::sync::mpsc; +use tokio::time::Duration; +use tokio::time::timeout; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Bytes; +use tokio_tungstenite::tungstenite::Message as WebSocketMessage; +use tokio_tungstenite::tungstenite::protocol::Role; +use tokio_util::sync::CancellationToken; + +#[test] +fn listen_unix_socket_parses_as_unix_socket_transport() { + assert_eq!( + AppServerTransport::from_listen_url("unix://"), + Ok(AppServerTransport::UnixSocket { + socket_path: default_control_socket_path() + }) + ); +} + +#[test] +fn listen_unix_socket_accepts_absolute_custom_path() { + assert_eq!( + AppServerTransport::from_listen_url("unix:///tmp/codex.sock"), + Ok(AppServerTransport::UnixSocket { + socket_path: absolute_path("/tmp/codex.sock") + }) + ); +} + +#[test] +fn listen_unix_socket_accepts_relative_custom_path() { + assert_eq!( + AppServerTransport::from_listen_url("unix://codex.sock"), + Ok(AppServerTransport::UnixSocket { + socket_path: AbsolutePathBuf::relative_to_current_dir("codex.sock") + .expect("relative path should resolve") + }) + ); +} + +#[tokio::test] +async fn control_socket_acceptor_forwards_websocket_text_messages_and_pings() { + let temp_dir = tempfile::TempDir::new().expect("temp dir"); + let socket_path = test_socket_path(temp_dir.path()); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let accept_handle = start_control_socket_acceptor( + socket_path.clone(), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("control socket acceptor should start"); + + let stream = connect_to_socket(socket_path.as_path()) + .await + .expect("client should connect"); + let mut websocket = WebSocketStream::from_raw_socket(stream, Role::Client, None).await; + + let opened = timeout(Duration::from_secs(1), transport_event_rx.recv()) + .await + .expect("connection opened event should arrive") + .expect("connection opened event"); + let connection_id = match opened { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + _ => panic!("expected connection opened event"), + }; + + let notification = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + websocket + .send(WebSocketMessage::Text( + serde_json::to_string(¬ification) + .expect("notification should serialize") + .into(), + )) + .await + .expect("notification should send"); + + let incoming = timeout(Duration::from_secs(1), transport_event_rx.recv()) + .await + .expect("incoming message event should arrive") + .expect("incoming message event"); + assert_eq!( + match incoming { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => (incoming_connection_id, message), + _ => panic!("expected incoming message event"), + }, + (connection_id, notification) + ); + + websocket + .send(WebSocketMessage::Ping(Bytes::from_static(b"check"))) + .await + .expect("ping should send"); + let pong = timeout(Duration::from_secs(1), websocket.next()) + .await + .expect("pong should arrive") + .expect("pong frame") + .expect("pong should be valid"); + assert_eq!(pong, WebSocketMessage::Pong(Bytes::from_static(b"check"))); + + websocket.close(None).await.expect("close should send"); + let closed = timeout(Duration::from_secs(1), transport_event_rx.recv()) + .await + .expect("connection closed event should arrive") + .expect("connection closed event"); + assert!(matches!( + closed, + TransportEvent::ConnectionClosed { + connection_id: closed_connection_id, + } if closed_connection_id == connection_id + )); + + shutdown_token.cancel(); + accept_handle.await.expect("acceptor should join"); + assert_socket_path_removed(socket_path.as_path()); +} + +#[cfg(unix)] +#[tokio::test] +async fn control_socket_file_is_private_after_bind() { + use std::os::unix::fs::PermissionsExt; + + let temp_dir = tempfile::TempDir::new().expect("temp dir"); + let socket_path = test_socket_path(temp_dir.path()); + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let accept_handle = start_control_socket_acceptor( + socket_path.clone(), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("control socket acceptor should start"); + + let metadata = tokio::fs::metadata(socket_path.as_path()) + .await + .expect("socket metadata should exist"); + assert_eq!(metadata.permissions().mode() & 0o777, 0o600); + + shutdown_token.cancel(); + accept_handle.await.expect("acceptor should join"); +} + +fn absolute_path(path: &str) -> AbsolutePathBuf { + AbsolutePathBuf::from_absolute_path(path).expect("absolute path") +} + +fn default_control_socket_path() -> AbsolutePathBuf { + let codex_home = find_codex_home().expect("codex home"); + app_server_control_socket_path(&codex_home).expect("default control socket path") +} + +fn test_socket_path(temp_dir: &Path) -> AbsolutePathBuf { + AbsolutePathBuf::from_absolute_path( + temp_dir + .join("app-server-control") + .join("app-server-control.sock"), + ) + .expect("socket path should resolve") +} + +async fn connect_to_socket(socket_path: &Path) -> IoResult { + UnixStream::connect(socket_path).await +} + +#[cfg(unix)] +fn assert_socket_path_removed(socket_path: &Path) { + assert!(!socket_path.exists()); +} + +#[cfg(windows)] +fn assert_socket_path_removed(_socket_path: &Path) { + // uds_windows uses a regular filesystem path as its rendezvous point, + // but there is no Unix socket filesystem node to assert on. +} diff --git a/codex-rs/app-server/src/transport/websocket.rs b/codex-rs/app-server/src/transport/websocket.rs index 2647d7c76813..1840231c3c57 100644 --- a/codex-rs/app-server/src/transport/websocket.rs +++ b/codex-rs/app-server/src/transport/websocket.rs @@ -11,10 +11,10 @@ use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; use axum::Router; use axum::body::Body; +use axum::body::Bytes; use axum::extract::ConnectInfo; use axum::extract::State; -use axum::extract::ws::Message as WebSocketMessage; -use axum::extract::ws::WebSocket; +use axum::extract::ws::Message as AxumWebSocketMessage; use axum::extract::ws::WebSocketUpgrade; use axum::http::HeaderMap; use axum::http::Request; @@ -37,6 +37,7 @@ use std::sync::Arc; use tokio::net::TcpListener; use tokio::sync::mpsc; use tokio::task::JoinHandle; +use tokio_tungstenite::tungstenite::Message as TungsteniteWebSocketMessage; use tokio_util::sync::CancellationToken; use tracing::error; use tracing::info; @@ -112,11 +113,12 @@ async fn websocket_upgrade_handler( ); return (err.status_code(), err.message()).into_response(); } - let connection_id = next_connection_id(); info!(%peer_addr, "websocket client connected"); websocket .on_upgrade(move |stream| async move { - run_websocket_connection(connection_id, stream, state.transport_event_tx).await; + let (websocket_writer, websocket_reader) = stream.split(); + run_websocket_connection(websocket_writer, websocket_reader, state.transport_event_tx) + .await; }) .into_response() } @@ -162,11 +164,16 @@ pub(crate) async fn start_websocket_acceptor( })) } -async fn run_websocket_connection( - connection_id: ConnectionId, - websocket_stream: WebSocket, +pub(crate) async fn run_websocket_connection( + websocket_writer: impl futures::sink::Sink + Send + 'static, + websocket_reader: impl futures::stream::Stream> + Send + 'static, transport_event_tx: mpsc::Sender, -) { +) where + M: AppServerWebSocketMessage + Send + 'static, + SinkError: Send + 'static, + StreamError: std::fmt::Display + Send + 'static, +{ + let connection_id = next_connection_id(); let (writer_tx, writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); let writer_tx_for_reader = writer_tx.clone(); let disconnect_token = CancellationToken::new(); @@ -183,9 +190,7 @@ async fn run_websocket_connection( return; } - let (websocket_writer, websocket_reader) = websocket_stream.split(); - let (writer_control_tx, writer_control_rx) = - mpsc::channel::(CHANNEL_CAPACITY); + let (writer_control_tx, writer_control_rx) = mpsc::channel::(CHANNEL_CAPACITY); let mut outbound_task = tokio::spawn(run_websocket_outbound_loop( websocket_writer, writer_rx, @@ -217,12 +222,74 @@ async fn run_websocket_connection( .await; } -async fn run_websocket_outbound_loop( - mut websocket_writer: futures::stream::SplitSink, +pub(crate) enum IncomingWebSocketMessage { + Text(String), + Binary, + Ping(Bytes), + Pong, + Close, +} + +/// Converts concrete WebSocket message types into the small message surface the +/// app-server transport needs, and constructs the only outbound frames it +/// sends directly. +pub(crate) trait AppServerWebSocketMessage: Sized { + fn text(text: String) -> Self; + fn pong(payload: Bytes) -> Self; + fn into_incoming(self) -> Option; +} + +impl AppServerWebSocketMessage for AxumWebSocketMessage { + fn text(text: String) -> Self { + Self::Text(text.into()) + } + + fn pong(payload: Bytes) -> Self { + Self::Pong(payload) + } + + fn into_incoming(self) -> Option { + Some(match self { + Self::Text(text) => IncomingWebSocketMessage::Text(text.to_string()), + Self::Binary(_) => IncomingWebSocketMessage::Binary, + Self::Ping(payload) => IncomingWebSocketMessage::Ping(payload), + Self::Pong(_) => IncomingWebSocketMessage::Pong, + Self::Close(_) => IncomingWebSocketMessage::Close, + }) + } +} + +impl AppServerWebSocketMessage for TungsteniteWebSocketMessage { + fn text(text: String) -> Self { + Self::Text(text.into()) + } + + fn pong(payload: Bytes) -> Self { + Self::Pong(payload) + } + + fn into_incoming(self) -> Option { + Some(match self { + Self::Text(text) => IncomingWebSocketMessage::Text(text.to_string()), + Self::Binary(_) => IncomingWebSocketMessage::Binary, + Self::Ping(payload) => IncomingWebSocketMessage::Ping(payload), + Self::Pong(_) => IncomingWebSocketMessage::Pong, + Self::Close(_) => IncomingWebSocketMessage::Close, + Self::Frame(_) => return None, + }) + } +} + +async fn run_websocket_outbound_loop( + websocket_writer: impl futures::sink::Sink + Send + 'static, mut writer_rx: mpsc::Receiver, - mut writer_control_rx: mpsc::Receiver, + mut writer_control_rx: mpsc::Receiver, disconnect_token: CancellationToken, -) { +) where + M: AppServerWebSocketMessage + Send + 'static, + SinkError: Send + 'static, +{ + tokio::pin!(websocket_writer); loop { tokio::select! { _ = disconnect_token.cancelled() => { @@ -243,7 +310,7 @@ async fn run_websocket_outbound_loop( let Some(json) = serialize_outgoing_message(queued_message.message) else { continue; }; - if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() { + if websocket_writer.send(M::text(json)).await.is_err() { break; } if let Some(write_complete_tx) = queued_message.write_complete_tx { @@ -254,14 +321,18 @@ async fn run_websocket_outbound_loop( } } -async fn run_websocket_inbound_loop( - mut websocket_reader: futures::stream::SplitStream, +async fn run_websocket_inbound_loop( + websocket_reader: impl futures::stream::Stream> + Send + 'static, transport_event_tx: mpsc::Sender, writer_tx_for_reader: mpsc::Sender, - writer_control_tx: mpsc::Sender, + writer_control_tx: mpsc::Sender, connection_id: ConnectionId, disconnect_token: CancellationToken, -) { +) where + M: AppServerWebSocketMessage + Send + 'static, + StreamError: std::fmt::Display + Send + 'static, +{ + tokio::pin!(websocket_reader); loop { tokio::select! { _ = disconnect_token.cancelled() => { @@ -269,33 +340,37 @@ async fn run_websocket_inbound_loop( } incoming_message = websocket_reader.next() => { match incoming_message { - Some(Ok(WebSocketMessage::Text(text))) => { - if !forward_incoming_message( - &transport_event_tx, - &writer_tx_for_reader, - connection_id, - text.as_ref(), - ) - .await - { - break; - } - } - Some(Ok(WebSocketMessage::Ping(payload))) => { - match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) { - Ok(()) => {} - Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break, - Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { - warn!("websocket control queue full while replying to ping; closing connection"); + Some(Ok(message)) => match message.into_incoming() { + Some(IncomingWebSocketMessage::Text(text)) => { + if !forward_incoming_message( + &transport_event_tx, + &writer_tx_for_reader, + connection_id, + &text, + ) + .await + { break; } } - } - Some(Ok(WebSocketMessage::Pong(_))) => {} - Some(Ok(WebSocketMessage::Close(_))) | None => break, - Some(Ok(WebSocketMessage::Binary(_))) => { - warn!("dropping unsupported binary websocket message"); - } + Some(IncomingWebSocketMessage::Ping(payload)) => { + match writer_control_tx.try_send(M::pong(payload)) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break, + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + warn!("websocket control queue full while replying to ping; closing connection"); + break; + } + } + } + Some(IncomingWebSocketMessage::Pong) => {} + Some(IncomingWebSocketMessage::Close) => break, + Some(IncomingWebSocketMessage::Binary) => { + warn!("dropping unsupported binary websocket message"); + } + None => {} + }, + None => break, Some(Err(err)) => { warn!("websocket receive error: {err}"); break; diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 9852a2cd5fa3..baba813b9a09 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -28,6 +28,7 @@ use codex_tui::AppExitInfo; use codex_tui::Cli as TuiCli; use codex_tui::ExitReason; use codex_tui::UpdateAction; +use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_cli::CliConfigOverrides; use owo_colors::OwoColorize; use std::io::IsTerminal; @@ -392,7 +393,7 @@ struct AppServerCommand { subcommand: Option, /// Transport endpoint URL. Supported values: `stdio://` (default), - /// `ws://IP:PORT`, `off`. + /// `unix://`, `unix://PATH`, `ws://IP:PORT`, `off`. #[arg( long = "listen", value_name = "URL", @@ -436,6 +437,9 @@ struct ExecServerCommand { #[derive(Debug, clap::Subcommand)] #[allow(clippy::enum_variant_names)] enum AppServerSubcommand { + /// Proxy stdio bytes to the running app-server control socket. + Proxy(AppServerProxyCommand), + /// [experimental] Generate TypeScript bindings for the app server protocol. GenerateTs(GenerateTsCommand), @@ -447,6 +451,13 @@ enum AppServerSubcommand { GenerateInternalJsonSchema(GenerateInternalJsonSchemaCommand), } +#[derive(Debug, Args)] +struct AppServerProxyCommand { + /// Path to the app-server Unix domain socket to connect to. + #[arg(long = "sock", value_name = "SOCKET_PATH", value_parser = parse_socket_path)] + socket_path: Option, +} + #[derive(Debug, Args)] struct GenerateTsCommand { /// Output directory where .ts files will be written @@ -483,8 +494,13 @@ struct GenerateInternalJsonSchemaCommand { #[derive(Debug, Parser)] struct StdioToUdsCommand { /// Path to the Unix domain socket to connect to. - #[arg(value_name = "SOCKET_PATH")] - socket_path: PathBuf, + #[arg(value_name = "SOCKET_PATH", value_parser = parse_socket_path)] + socket_path: AbsolutePathBuf, +} + +fn parse_socket_path(raw: &str) -> Result { + AbsolutePathBuf::relative_to_current_dir(raw) + .map_err(|err| format!("failed to resolve socket path `{raw}`: {err}")) } fn format_exit_messages(exit_info: AppExitInfo, color_enabled: bool) -> Vec { @@ -803,6 +819,16 @@ async fn cli_main(arg0_paths: Arg0DispatchPaths) -> anyhow::Result<()> { ) .await?; } + Some(AppServerSubcommand::Proxy(proxy_cli)) => { + let socket_path = match proxy_cli.socket_path { + Some(socket_path) => socket_path, + None => { + let codex_home = find_codex_home()?; + codex_app_server::app_server_control_socket_path(&codex_home)? + } + }; + codex_stdio_to_uds::run(socket_path.as_path()).await?; + } Some(AppServerSubcommand::GenerateTs(gen_cli)) => { let options = codex_app_server_protocol::GenerateTsOptions { experimental_api: gen_cli.experimental, @@ -1407,6 +1433,7 @@ fn reject_remote_mode_for_app_server_subcommand( ) -> anyhow::Result<()> { let subcommand_name = match subcommand { None => "app-server", + Some(AppServerSubcommand::Proxy(_)) => "app-server proxy", Some(AppServerSubcommand::GenerateTs(_)) => "app-server generate-ts", Some(AppServerSubcommand::GenerateJsonSchema(_)) => "app-server generate-json-schema", Some(AppServerSubcommand::GenerateInternalJsonSchema(_)) => { @@ -1723,6 +1750,12 @@ mod tests { app_server } + fn default_app_server_socket_path() -> AbsolutePathBuf { + let codex_home = find_codex_home().expect("codex home"); + codex_app_server::app_server_control_socket_path(&codex_home) + .expect("default app-server socket path") + } + #[test] fn debug_prompt_input_parses_prompt_and_images() { let cli = MultitoolCli::try_parse_from([ @@ -2198,6 +2231,32 @@ mod tests { ); } + #[test] + fn app_server_listen_unix_socket_url_parses() { + let app_server = + app_server_from_args(["codex", "app-server", "--listen", "unix://"].as_ref()); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::UnixSocket { + socket_path: default_app_server_socket_path() + } + ); + } + + #[test] + fn app_server_listen_unix_socket_path_parses() { + let app_server = app_server_from_args( + ["codex", "app-server", "--listen", "unix:///tmp/codex.sock"].as_ref(), + ); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::UnixSocket { + socket_path: AbsolutePathBuf::from_absolute_path("/tmp/codex.sock") + .expect("absolute path should parse") + } + ); + } + #[test] fn app_server_listen_off_parses() { let app_server = app_server_from_args(["codex", "app-server", "--listen", "off"].as_ref()); @@ -2211,6 +2270,45 @@ mod tests { assert!(parse_result.is_err()); } + #[test] + fn app_server_proxy_subcommand_parses() { + let app_server = app_server_from_args(["codex", "app-server", "proxy"].as_ref()); + assert!(matches!( + app_server.subcommand, + Some(AppServerSubcommand::Proxy(AppServerProxyCommand { + socket_path: None + })) + )); + } + + #[test] + fn app_server_proxy_sock_path_parses() { + let app_server = + app_server_from_args(["codex", "app-server", "proxy", "--sock", "codex.sock"].as_ref()); + let Some(AppServerSubcommand::Proxy(proxy)) = app_server.subcommand else { + panic!("expected proxy subcommand"); + }; + assert_eq!( + proxy.socket_path, + Some( + AbsolutePathBuf::relative_to_current_dir("codex.sock") + .expect("relative path should resolve") + ) + ); + } + + #[test] + fn reject_remote_auth_token_env_for_app_server_proxy() { + let subcommand = AppServerSubcommand::Proxy(AppServerProxyCommand { socket_path: None }); + let err = reject_remote_mode_for_app_server_subcommand( + /*remote*/ None, + Some("CODEX_REMOTE_AUTH_TOKEN"), + Some(&subcommand), + ) + .expect_err("app-server proxy should reject --remote-auth-token-env"); + assert!(err.to_string().contains("app-server proxy")); + } + #[test] fn app_server_capability_token_flags_parse() { let app_server = app_server_from_args(