From d81ae45c1720524af079ca941821267a86e20368 Mon Sep 17 00:00:00 2001 From: George Miao Date: Tue, 18 Nov 2025 23:37:29 -0500 Subject: [PATCH 1/5] feat(ws): add `Config` --- compio-io/src/compat/sync_stream.rs | 8 +- compio-io/src/read/buf.rs | 2 +- compio-io/src/util/mod.rs | 2 +- compio-io/src/write/buf.rs | 2 +- compio-ws/src/lib.rs | 158 ++++++++++++++++++++++++---- compio-ws/src/tls.rs | 37 +++---- 6 files changed, 159 insertions(+), 50 deletions(-) diff --git a/compio-io/src/compat/sync_stream.rs b/compio-io/src/compat/sync_stream.rs index c70be850..1c0e4a49 100644 --- a/compio-io/src/compat/sync_stream.rs +++ b/compio-io/src/compat/sync_stream.rs @@ -43,20 +43,20 @@ pub struct SyncStream { } impl SyncStream { - // 64MB max + // 64MiB max const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024; /// Creates a new `SyncStream` with default buffer sizes. /// - /// - Base capacity: 8KB - /// - Max buffer size: 64MB + /// - Base capacity: 8KiB + /// - Max buffer size: 64MiB pub fn new(stream: S) -> Self { Self::with_capacity(DEFAULT_BUF_SIZE, stream) } /// Creates a new `SyncStream` with a custom base capacity. /// - /// The maximum buffer size defaults to 64MB. + /// The maximum buffer size defaults to 64MiB. pub fn with_capacity(base_capacity: usize, stream: S) -> Self { Self::with_limits(base_capacity, Self::DEFAULT_MAX_BUFFER, stream) } diff --git a/compio-io/src/read/buf.rs b/compio-io/src/read/buf.rs index 9cc8ddeb..d38f304e 100644 --- a/compio-io/src/read/buf.rs +++ b/compio-io/src/read/buf.rs @@ -55,7 +55,7 @@ pub struct BufReader { impl BufReader { /// Creates a new `BufReader` with a default buffer capacity. The default is - /// currently 8 KB, but may change in the future. + /// currently 8 KiB, but may change in the future. pub fn new(reader: R) -> Self { Self::with_capacity(DEFAULT_BUF_SIZE, reader) } diff --git a/compio-io/src/util/mod.rs b/compio-io/src/util/mod.rs index 733d788e..216cc256 100644 --- a/compio-io/src/util/mod.rs +++ b/compio-io/src/util/mod.rs @@ -29,7 +29,7 @@ pub use split::Splittable; /// /// This is an asynchronous version of [`std::io::copy`][std]. /// -/// A heap-allocated copy buffer with 8 KB is created to take data from the +/// A heap-allocated copy buffer with 8 KiB is created to take data from the /// reader to the writer. pub async fn copy(reader: &mut R, writer: &mut W) -> IoResult { let mut buf = Vec::with_capacity(DEFAULT_BUF_SIZE); diff --git a/compio-io/src/write/buf.rs b/compio-io/src/write/buf.rs index c73633a5..09d3cda5 100644 --- a/compio-io/src/write/buf.rs +++ b/compio-io/src/write/buf.rs @@ -32,7 +32,7 @@ pub struct BufWriter { impl BufWriter { /// Creates a new `BufWriter` with a default buffer capacity. The default is - /// currently 8 KB, but may change in the future. + /// currently 8 KiB, but may change in the future. pub fn new(writer: W) -> Self { Self::with_capacity(DEFAULT_BUF_SIZE, writer) } diff --git a/compio-ws/src/lib.rs b/compio-ws/src/lib.rs index 8274abb9..2b332433 100644 --- a/compio-ws/src/lib.rs +++ b/compio-ws/src/lib.rs @@ -25,6 +25,121 @@ mod tls; pub use tls::*; pub use tungstenite; +/// Configuration for compio-ws. +/// +/// ## API Interface +/// +/// `_with_config` functions in this crate accept `impl Into`, so +/// following are all valid: +/// - [`Config`] +/// - [`WebSocketConfig`] (use custom WebSocket config with default remaining +/// settings) +/// - [`None`] (use default value) +pub struct Config { + /// WebSocket configuration from tungstenite. + websocket: Option, + + /// Base buffer size + buffer_size_base: usize, + + /// Maximum buffer size + buffer_size_limit: usize, + + /// Disable Nagle's algorithm. This only affects + /// [`connect_async_with_config()`] and [`connect_async_tls_with_config()`]. + disable_nagle: bool, +} + +impl Config { + // 128 KiB, see . + const DEFAULT_BUF_SIZE: usize = 128 * 1024; + // 64 MiB, the same as [`SyncStream`]. + const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024; + + /// Creates a new `Config` with default settings. + pub fn new() -> Self { + Self { + websocket: None, + buffer_size_base: Self::DEFAULT_BUF_SIZE, + buffer_size_limit: Self::DEFAULT_MAX_BUFFER, + disable_nagle: false, + } + } + + /// Get the WebSocket configuration. + pub fn websocket_config(&self) -> Option<&WebSocketConfig> { + self.websocket.as_ref() + } + + /// Get the base buffer size. + pub fn buffer_size_base(&self) -> usize { + self.buffer_size_base + } + + /// Get the maximum buffer size. + pub fn buffer_size_limit(&self) -> usize { + self.buffer_size_limit + } + + /// Set custom base buffer size. + /// + /// Default to 128 KiB. + pub fn with_buffer_size_base(mut self, size: usize) -> Self { + self.buffer_size_base = size; + self + } + + /// Set custom maximum buffer size. + /// + /// Default to 64 MiB. + pub fn with_buffer_size_limit(mut self, size: usize) -> Self { + self.buffer_size_limit = size; + self + } + + /// Set custom buffer sizes. + /// + /// Default to 128 KiB for base and 64 MiB for limit. + pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self { + self.buffer_size_base = base; + self.buffer_size_limit = limit; + self + } + + /// Disable Nagle's algorithm, i.e. `set_nodelay(true)`. + /// + /// Default to `false`. If you don't know what the Nagle's algorithm is, + /// better leave it to `false`. + pub fn disable_nagle(mut self, disable: bool) -> Self { + self.disable_nagle = disable; + self + } +} + +impl Default for Config { + fn default() -> Self { + Self::new() + } +} + +impl From for Config { + fn from(config: WebSocketConfig) -> Self { + Self { + websocket: Some(config), + ..Default::default() + } + } +} + +impl From> for Config { + fn from(config: Option) -> Self { + Self { + websocket: config, + ..Default::default() + } + } +} + /// A WebSocket stream that works with compio. #[derive(Debug)] pub struct WebSocketStream { @@ -135,11 +250,10 @@ impl IntoInner for WebSocketStream { /// Accepts a new WebSocket connection with the provided stream. /// -/// This function will internally call `server::accept` to create a -/// handshake representation and returns a future representing the -/// resolution of the WebSocket handshake. The returned future will resolve -/// to either `WebSocketStream` or `Error` depending if it's successful -/// or not. +/// This function will internally create a handshake representation and returns +/// a future representing the resolution of the WebSocket handshake. The +/// returned future will resolve to either [`WebSocketStream`] or [`WsError`] +/// depending on if it's successful or not. /// /// This is typically used after a socket has been accepted from a /// `TcpListener`. That socket is then passed to this function to perform @@ -151,11 +265,10 @@ where accept_hdr_async(stream, NoCallback).await } -/// The same as `accept_async()` but the one can specify a websocket -/// configuration. Please refer to `accept_async()` for more details. +/// Similar to [`accept_async()`] but user can specify a [`Config`]. pub async fn accept_async_with_config( stream: S, - config: Option, + config: impl Into, ) -> Result, WsError> where S: AsyncRead + AsyncWrite, @@ -164,7 +277,7 @@ where } /// Accepts a new WebSocket connection with the provided stream. /// -/// This function does the same as `accept_async()` but accepts an extra +/// This function does the same as [`accept_async()`] but accepts an extra /// callback for header processing. The callback receives headers of the /// incoming requests and is able to add extra headers to the reply. pub async fn accept_hdr_async(stream: S, callback: C) -> Result, WsError> @@ -175,19 +288,21 @@ where accept_hdr_with_config_async(stream, callback, None).await } -/// The same as `accept_hdr_async()` but the one can specify a websocket -/// configuration. Please refer to `accept_hdr_async()` for more details. +/// Similar to [`accept_hdr_async()`] but user can specify a [`Config`]. pub async fn accept_hdr_with_config_async( stream: S, callback: C, - config: Option, + config: impl Into, ) -> Result, WsError> where S: AsyncRead + AsyncWrite, C: Callback, { - let sync_stream = SyncStream::with_capacity(128 * 1024, stream); - let mut handshake_result = tungstenite::accept_hdr_with_config(sync_stream, callback, config); + let config = config.into(); + let sync_stream = + SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream); + let mut handshake_result = + tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket); loop { match handshake_result { @@ -223,7 +338,7 @@ where /// /// Internally, this creates a handshake representation and returns /// a future representing the resolution of the WebSocket handshake. The -/// returned future will resolve to either `WebSocketStream` or `Error` +/// returned future will resolve to either [`WebSocketStream`] or [`WsError`] /// depending on whether the handshake is successful. /// /// This is typically used for clients who have already established, for @@ -239,20 +354,21 @@ where client_async_with_config(request, stream, None).await } -/// The same as `client_async()` but the one can specify a websocket -/// configuration. Please refer to `client_async()` for more details. -pub async fn client_async_with_config( +/// Similar to [`client_async()`] but user can specify a [`Config`]. +async fn client_async_with_config( request: R, stream: S, - config: Option, + config: impl Into, ) -> Result<(WebSocketStream, tungstenite::handshake::client::Response), WsError> where R: IntoClientRequest, S: AsyncRead + AsyncWrite, { - let sync_stream = SyncStream::with_capacity(128 * 1024, stream); + let config = config.into(); + let sync_stream = + SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream); let mut handshake_result = - tungstenite::client::client_with_config(request, sync_stream, config); + tungstenite::client::client_with_config(request, sync_stream, config.websocket); loop { match handshake_result { diff --git a/compio-ws/src/tls.rs b/compio-ws/src/tls.rs index 49abe2a0..86a3a328 100644 --- a/compio-ws/src/tls.rs +++ b/compio-ws/src/tls.rs @@ -10,7 +10,7 @@ use tungstenite::{ stream::Mode, }; -use crate::{WebSocketConfig, WebSocketStream, client_async_with_config}; +use crate::{Config, WebSocketStream, client_async_with_config}; mod encryption { #[cfg(feature = "native-tls")] @@ -180,13 +180,13 @@ where client_async_tls_with_config(request, stream, None, None).await } -/// The same as `client_async_tls()` but the one can specify a websocket +/// Similar to `client_async_tls()` but the one can specify a websocket /// configuration, and an optional connector. pub async fn client_async_tls_with_config( request: R, stream: S, connector: Option, - config: Option, + config: impl Into, ) -> Result<(WebSocketStream>, Response), Error> where R: IntoClientRequest, @@ -212,38 +212,31 @@ pub async fn connect_async( where R: IntoClientRequest, { - connect_async_with_config(request, None, false).await + connect_async_with_config(request, None).await } -/// The same as `connect_async()` but the one can specify a websocket -/// configuration. `disable_nagle` specifies if the Nagle's algorithm must be -/// disabled, i.e. `set_nodelay(true)`. If you don't know what the Nagle's -/// algorithm is, better leave it to `false`. +/// Similar to [`connect_async()`], but user can specify a [`Config`]. pub async fn connect_async_with_config( request: R, - config: Option, - disable_nagle: bool, + config: impl Into, ) -> Result<(WebSocketStream, Response), Error> where R: IntoClientRequest, { - connect_async_tls_with_config(request, config, disable_nagle, None).await + connect_async_tls_with_config(request, config, None).await } -/// The same as `connect_async()` but the one can specify a websocket -/// configuration, a TLS connector, and whether to disable Nagle's algorithm. -/// `disable_nagle` specifies if the Nagle's algorithm must be disabled, i.e. -/// `set_nodelay(true)`. If you don't know what the Nagle's algorithm is, better -/// leave it to `false`. +/// Similar to [`connect_async()`], but user can specify a [`Config`] and an +/// optional [`TlsConnector`]. pub async fn connect_async_tls_with_config( request: R, - config: Option, - disable_nagle: bool, + config: impl Into, connector: Option, ) -> Result<(WebSocketStream, Response), Error> where R: IntoClientRequest, { + let config = config.into(); let request: Request = request.into_client_request()?; // We don't check if it's an IPv6 address because `std` handles it internally. @@ -253,10 +246,10 @@ where .ok_or(Error::Url(tungstenite::error::UrlError::NoHostName))?; let port = port(&request)?; - let socket = - TcpStream::connect_with_options((domain, port), TcpOpts::new().nodelay(disable_nagle)) - .await - .map_err(Error::Io)?; + let opts = TcpOpts::new().nodelay(config.disable_nagle); + let socket = TcpStream::connect_with_options((domain, port), opts) + .await + .map_err(Error::Io)?; client_async_tls_with_config(request, socket, connector, config).await } From 2cf487111f3cbf5746b753795bf0d52c5f444c54 Mon Sep 17 00:00:00 2001 From: George Miao Date: Tue, 18 Nov 2025 23:42:57 -0500 Subject: [PATCH 2/5] fix(ws): example --- compio-ws/examples/client_tls.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compio-ws/examples/client_tls.rs b/compio-ws/examples/client_tls.rs index 716bba5e..0e97e88c 100644 --- a/compio-ws/examples/client_tls.rs +++ b/compio-ws/examples/client_tls.rs @@ -92,7 +92,7 @@ async fn main() -> Result<(), Box> { let connector = None; let (mut websocket, _response) = - connect_async_tls_with_config("wss://127.0.0.1:9002", None, false, connector).await?; + connect_async_tls_with_config("wss://127.0.0.1:9002", None, connector).await?; println!("Successfully connected to WebSocket TLS server!"); println!(); From ed39db14a94553a3ee4a881ba6a9edae34a0ba40 Mon Sep 17 00:00:00 2001 From: Pop Date: Tue, 18 Nov 2025 23:44:43 -0500 Subject: [PATCH 3/5] fix(ws): client_async_with_config is private Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- compio-ws/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compio-ws/src/lib.rs b/compio-ws/src/lib.rs index 2b332433..5a454541 100644 --- a/compio-ws/src/lib.rs +++ b/compio-ws/src/lib.rs @@ -355,7 +355,7 @@ where } /// Similar to [`client_async()`] but user can specify a [`Config`]. -async fn client_async_with_config( +pub async fn client_async_with_config( request: R, stream: S, config: impl Into, From 0221d92f7e02e0d6674b2d4ec86e4c7d71f6fba6 Mon Sep 17 00:00:00 2001 From: Pop Date: Tue, 18 Nov 2025 23:45:12 -0500 Subject: [PATCH 4/5] style(ws): docs link Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- compio-ws/src/tls.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compio-ws/src/tls.rs b/compio-ws/src/tls.rs index 86a3a328..8e0c4c78 100644 --- a/compio-ws/src/tls.rs +++ b/compio-ws/src/tls.rs @@ -180,7 +180,7 @@ where client_async_tls_with_config(request, stream, None, None).await } -/// Similar to `client_async_tls()` but the one can specify a websocket +/// Similar to [`client_async_tls()`] but the one can specify a websocket /// configuration, and an optional connector. pub async fn client_async_tls_with_config( request: R, From 246ec385049e802b287e338dc22d058a81f12b3b Mon Sep 17 00:00:00 2001 From: George Miao Date: Wed, 19 Nov 2025 00:04:51 -0500 Subject: [PATCH 5/5] refactor(tls): match Infallible --- .gitignore | 10 +++++++--- compio-tls/src/stream.rs | 12 +++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 14b97702..1714b007 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,11 @@ /target /Cargo.lock -/.vscode /.cargo -.idea + /.direnv -.envrc +/.envrc + +# Editor directories +/.vscode +/.zed +/.idea diff --git a/compio-tls/src/stream.rs b/compio-tls/src/stream.rs index 00ef2fce..10b072f9 100644 --- a/compio-tls/src/stream.rs +++ b/compio-tls/src/stream.rs @@ -81,7 +81,7 @@ impl AsyncRead for TlsStream { slice.fill(MaybeUninit::new(0)); // SAFETY: The memory has been initialized let slice = - unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) }; + unsafe { std::slice::from_raw_parts_mut::(slice.as_mut_ptr().cast(), slice.len()) }; match &mut self.0 { #[cfg(feature = "native-tls")] TlsStreamInner::NativeTls(s) => loop { @@ -115,10 +115,7 @@ impl AsyncRead for TlsStream { BufResult(res, buf) } #[cfg(not(any(feature = "native-tls", feature = "rustls")))] - TlsStreamInner::None(f, ..) => { - let _slice: &mut [u8] = slice; - match *f {} - } + TlsStreamInner::None(f, ..) => match *f {}, } } } @@ -159,10 +156,7 @@ impl AsyncWrite for TlsStream { BufResult(res, buf) } #[cfg(not(any(feature = "native-tls", feature = "rustls")))] - TlsStreamInner::None(f, ..) => { - let _slice: &[u8] = slice; - match *f {} - } + TlsStreamInner::None(f, ..) => match *f {}, } }