diff --git a/compio-io/src/buffer.rs b/compio-io/src/buffer.rs index 61806a12..f07016e6 100644 --- a/compio-io/src/buffer.rs +++ b/compio-io/src/buffer.rs @@ -26,7 +26,7 @@ impl Inner { #[inline] fn reset(&mut self) { self.pos = 0; - unsafe { self.buf.set_len(0) }; + self.buf.clear(); } #[inline] @@ -34,6 +34,35 @@ impl Inner { &self.buf[self.pos..] } + pub fn reserve_exact(&mut self, additional: usize) { + self.buf.reserve_exact(additional); + } + + pub fn extend_from_slice(&mut self, data: &[u8]) { + self.buf.extend_from_slice(data); + } + + fn compact_to(&mut self, capacity: usize, max_capacity: usize) { + if self.pos > 0 && self.pos < self.buf.len() { + let buf_len = self.buf.len(); + let remaining = buf_len - self.pos; + self.buf.copy_within(self.pos..buf_len, 0); + + // SAFETY: We're setting the length to the amount of data we just moved. + // The data from 0..remaining is initialized (just moved from read_pos..buf_len) + unsafe { + self.buf.set_len(remaining); + } + self.pos = 0; + } else if self.pos >= self.buf.len() { + // All data consumed, reset buffer + self.reset(); + if self.buf.capacity() > max_capacity { + self.buf.shrink_to(capacity); + } + } + } + #[inline] pub(crate) fn into_slice(self) -> Slice { let pos = self.pos; @@ -138,6 +167,12 @@ impl Buffer { self.inner_mut().buf.reserve(additional); } + /// Compact the buffer to the given capacity, if the current capacity is + /// larger than the given maximum capacity. + pub fn compact_to(&mut self, capacity: usize, max_capacity: usize) { + self.inner_mut().compact_to(capacity, max_capacity); + } + /// Execute a funcition with ownership of the buffer, and restore the buffer /// afterwards pub async fn with(&mut self, func: F) -> IoResult @@ -175,7 +210,7 @@ impl Buffer { .await?; if written == 0 { return Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, + std::io::ErrorKind::WriteZero, "cannot flush all buffer data", )); } diff --git a/compio-io/src/compat.rs b/compio-io/src/compat/async_stream.rs similarity index 56% rename from compio-io/src/compat.rs rename to compio-io/src/compat/async_stream.rs index d87b7768..80987581 100644 --- a/compio-io/src/compat.rs +++ b/compio-io/src/compat/async_stream.rs @@ -1,180 +1,12 @@ -//! Compat wrappers for interop with other crates. - use std::{ fmt::Debug, - io::{self, BufRead, Read, Write}, + io::{self, BufRead}, mem::MaybeUninit, pin::Pin, task::{Context, Poll}, }; -use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; - -use crate::{PinBoxFuture, buffer::Buffer, util::DEFAULT_BUF_SIZE}; - -/// A wrapper for [`AsyncRead`](crate::AsyncRead) + -/// [`AsyncWrite`](crate::AsyncWrite), providing sync traits impl. -/// -/// The sync methods will return [`io::ErrorKind::WouldBlock`] error if the -/// inner buffer needs more data. -#[derive(Debug)] -pub struct SyncStream { - stream: S, - eof: bool, - read_buffer: Buffer, - write_buffer: Buffer, -} - -impl SyncStream { - /// Create [`SyncStream`] with the stream and default buffer size. - pub fn new(stream: S) -> Self { - Self::with_capacity(DEFAULT_BUF_SIZE, stream) - } - - /// Create [`SyncStream`] with the stream and buffer size. - pub fn with_capacity(cap: usize, stream: S) -> Self { - Self { - stream, - eof: false, - read_buffer: Buffer::with_capacity(cap), - write_buffer: Buffer::with_capacity(cap), - } - } - - /// Get if the stream is at EOF. - pub fn is_eof(&self) -> bool { - self.eof - } - - /// Get the reference of the inner stream. 
- pub fn get_ref(&self) -> &S { - &self.stream - } - - /// Get the mutable reference of the inner stream. - pub fn get_mut(&mut self) -> &mut S { - &mut self.stream - } - - fn flush_impl(&mut self) -> io::Result<()> { - if !self.write_buffer.is_empty() { - Err(would_block("need to flush the write buffer")) - } else { - Ok(()) - } - } - - /// Pull some bytes from this source into the specified buffer. - pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit]) -> io::Result { - let slice = self.fill_buf()?; - let amt = buf.len().min(slice.len()); - // SAFETY: the length is valid - buf[..amt] - .copy_from_slice(unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), amt) }); - self.consume(amt); - Ok(amt) - } -} - -impl Read for SyncStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut slice = self.fill_buf()?; - slice.read(buf).inspect(|res| { - self.consume(*res); - }) - } - - #[cfg(feature = "read_buf")] - fn read_buf(&mut self, mut buf: io::BorrowedCursor<'_>) -> io::Result<()> { - let mut slice = self.fill_buf()?; - let old_written = buf.written(); - slice.read_buf(buf.reborrow())?; - let len = buf.written() - old_written; - self.consume(len); - Ok(()) - } -} - -impl BufRead for SyncStream { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - if self.read_buffer.all_done() { - self.read_buffer.reset(); - } - - if self.read_buffer.slice().is_empty() && !self.eof { - return Err(would_block("need to fill the read buffer")); - } - - Ok(self.read_buffer.slice()) - } - - fn consume(&mut self, amt: usize) { - self.read_buffer.advance(amt); - } -} - -impl Write for SyncStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - if self.write_buffer.need_flush() { - self.flush_impl()?; - } - - let written = self.write_buffer.with_sync(|mut inner| { - let len = buf.len().min(inner.buf_capacity() - inner.buf_len()); - unsafe { - std::ptr::copy_nonoverlapping( - buf.as_ptr(), - inner.as_buf_mut_ptr().add(inner.buf_len()), - len, - ); - inner.set_buf_init(inner.buf_len() + len); - } - BufResult(Ok(len), inner) - })?; - - Ok(written) - } - - fn flush(&mut self) -> io::Result<()> { - // Related PR: - // https://github.com/sfackler/rust-openssl/pull/1922 - // After this PR merged, we can use self.flush_impl() - Ok(()) - } -} - -fn would_block(msg: &str) -> io::Error { - io::Error::new(io::ErrorKind::WouldBlock, msg) -} - -impl SyncStream { - /// Fill the read buffer. - pub async fn fill_read_buf(&mut self) -> io::Result { - let stream = &mut self.stream; - let len = self - .read_buffer - .with(|b| async move { - let len = b.buf_len(); - let b = b.slice(len..); - stream.read(b).await.into_inner() - }) - .await?; - if len == 0 { - self.eof = true; - } - Ok(len) - } -} - -impl SyncStream { - /// Flush all data in the write buffer. - pub async fn flush_write_buf(&mut self) -> io::Result { - let stream = &mut self.stream; - let len = self.write_buffer.flush_to(stream).await?; - stream.flush().await?; - Ok(len) - } -} +use crate::{PinBoxFuture, compat::SyncStream}; /// A stream wrapper for [`futures_util::io`] traits. pub struct AsyncStream { diff --git a/compio-io/src/compat/mod.rs b/compio-io/src/compat/mod.rs new file mode 100644 index 00000000..fbb0427a --- /dev/null +++ b/compio-io/src/compat/mod.rs @@ -0,0 +1,7 @@ +//! Compat wrappers for interop with other crates. 
+ +mod sync_stream; +pub use sync_stream::*; + +mod async_stream; +pub use async_stream::*; diff --git a/compio-io/src/compat/sync_stream.rs b/compio-io/src/compat/sync_stream.rs new file mode 100644 index 00000000..c70be850 --- /dev/null +++ b/compio-io/src/compat/sync_stream.rs @@ -0,0 +1,301 @@ +use std::{ + io::{self, BufRead, Read, Write}, + mem::MaybeUninit, +}; + +use compio_buf::{BufResult, IntoInner, IoBuf}; + +use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE}; + +/// A growable buffered stream adapter that bridges async I/O with sync traits. +/// +/// # Buffer Growth Strategy +/// +/// - **Read buffer**: Grows as needed to accommodate incoming data, up to +/// `max_buffer_size` +/// - **Write buffer**: Grows as needed for outgoing data, up to +/// `max_buffer_size` +/// - Both buffers shrink back to `base_capacity` when fully consumed and +/// capacity exceeds 4x base +/// +/// # Usage Pattern +/// +/// The sync `Read` and `Write` implementations will return `WouldBlock` errors +/// when buffers need servicing via the async methods: +/// +/// - Call `fill_read_buf()` when `Read::read()` returns `WouldBlock` +/// - Call `flush_write_buf()` when `Write::write()` returns `WouldBlock` +/// +/// # Note on flush() +/// +/// The `Write::flush()` method intentionally returns `Ok(())` without checking +/// if there's buffered data. This is for compatibility with libraries like +/// tungstenite that call `flush()` after every write. Actual flushing happens +/// via the async `flush_write_buf()` method. +#[derive(Debug)] +pub struct SyncStream { + inner: S, + read_buf: Buffer, + write_buf: Buffer, + eof: bool, + base_capacity: usize, + max_buffer_size: usize, +} + +impl SyncStream { + // 64MB max + const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024; + + /// Creates a new `SyncStream` with default buffer sizes. + /// + /// - Base capacity: 8KB + /// - Max buffer size: 64MB + pub fn new(stream: S) -> Self { + Self::with_capacity(DEFAULT_BUF_SIZE, stream) + } + + /// Creates a new `SyncStream` with a custom base capacity. + /// + /// The maximum buffer size defaults to 64MB. + pub fn with_capacity(base_capacity: usize, stream: S) -> Self { + Self::with_limits(base_capacity, Self::DEFAULT_MAX_BUFFER, stream) + } + + /// Creates a new `SyncStream` with custom base capacity and maximum + /// buffer size. + pub fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self { + Self { + inner: stream, + read_buf: Buffer::with_capacity(base_capacity), + write_buf: Buffer::with_capacity(base_capacity), + eof: false, + base_capacity, + max_buffer_size, + } + } + + /// Returns a reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Returns a mutable reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consumes the `SyncStream`, returning the underlying stream. + pub fn into_inner(self) -> S { + self.inner + } + + /// Returns `true` if the stream has reached EOF. + pub fn is_eof(&self) -> bool { + self.eof + } + + /// Returns the available bytes in the read buffer. + fn available_read(&self) -> &[u8] { + self.read_buf.slice() + } + + /// Marks `amt` bytes as consumed from the read buffer. + /// + /// Resets the buffer when all data is consumed and shrinks capacity + /// if it has grown significantly beyond the base capacity. 
+ fn consume_read(&mut self, amt: usize) { + let all_done = self.read_buf.advance(amt); + + // Shrink oversized buffers back to base capacity + if all_done { + self.read_buf + .compact_to(self.base_capacity, self.max_buffer_size); + } + } + + /// Pull some bytes from this source into the specified buffer. + pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit]) -> io::Result { + let available = self.fill_buf()?; + + let to_read = available.len().min(buf.len()); + buf[..to_read].copy_from_slice(unsafe { + std::slice::from_raw_parts(available.as_ptr().cast(), to_read) + }); + self.consume(to_read); + + Ok(to_read) + } +} + +impl Read for SyncStream { + /// Reads data from the internal buffer. + /// + /// Returns `WouldBlock` if the buffer is empty and not at EOF, + /// indicating that `fill_read_buf()` should be called. + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut slice = self.fill_buf()?; + slice.read(buf).inspect(|res| { + self.consume(*res); + }) + } + + #[cfg(feature = "read_buf")] + fn read_buf(&mut self, mut buf: io::BorrowedCursor<'_>) -> io::Result<()> { + let mut slice = self.fill_buf()?; + let old_written = buf.written(); + slice.read_buf(buf.reborrow())?; + let len = buf.written() - old_written; + self.consume(len); + Ok(()) + } +} + +impl BufRead for SyncStream { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + let available = self.available_read(); + + if available.is_empty() && !self.eof { + return Err(would_block("need to fill read buffer")); + } + + Ok(available) + } + + fn consume(&mut self, amt: usize) { + self.consume_read(amt); + } +} + +impl Write for SyncStream { + /// Writes data to the internal buffer. + /// + /// Returns `WouldBlock` if the buffer needs flushing or has reached max + /// capacity. In the latter case, it may write partial data before + /// returning `WouldBlock`. + fn write(&mut self, buf: &[u8]) -> io::Result { + // Check if we should flush first + if self.write_buf.need_flush() && !self.write_buf.is_empty() { + return Err(would_block("need to flush write buffer")); + } + + let written = self.write_buf.with_sync(|mut inner| { + let res = if inner.buf_len() + buf.len() > self.max_buffer_size { + let space = self.max_buffer_size - inner.buf_len(); + if space == 0 { + Err(would_block("write buffer full, need to flush")) + } else { + inner.extend_from_slice(&buf[..space]); + Ok(space) + } + } else { + inner.extend_from_slice(buf); + Ok(buf.len()) + }; + BufResult(res, inner) + })?; + + Ok(written) + } + + /// Returns `Ok(())` without checking for buffered data. + /// + /// **Important**: This does NOT actually flush data to the underlying + /// stream. This behavior is intentional for compatibility with + /// libraries like tungstenite that call `flush()` after every write + /// operation. The actual async flush happens when `flush_write_buf()` + /// is called. + /// + /// This prevents spurious errors in sync code that expects `flush()` to + /// succeed after successfully buffering data. + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +fn would_block(msg: &str) -> io::Error { + io::Error::new(io::ErrorKind::WouldBlock, msg) +} + +impl SyncStream { + /// Fills the read buffer by reading from the underlying async stream. + /// + /// This method: + /// 1. Compacts the buffer if there's unconsumed data + /// 2. Ensures there's space for at least `base_capacity` more bytes + /// 3. Reads data from the underlying stream + /// 4. 
Returns the number of bytes read (0 indicates EOF) + /// + /// # Errors + /// + /// Returns an error if: + /// - The read buffer has reached `max_buffer_size` + /// - The underlying stream returns an error + pub async fn fill_read_buf(&mut self) -> io::Result { + if self.eof { + return Ok(0); + } + + // Compact buffer, move unconsumed data to the front + self.read_buf + .compact_to(self.base_capacity, self.max_buffer_size); + + let read = self + .read_buf + .with(|mut inner| async { + let current_len = inner.buf_len(); + + if current_len >= self.max_buffer_size { + return BufResult( + Err(io::Error::new( + io::ErrorKind::OutOfMemory, + format!("read buffer size limit ({}) exceeded", self.max_buffer_size), + )), + inner, + ); + } + + let capacity = inner.buf_capacity(); + let available_space = capacity - current_len; + + // If target space is less than base capacity, grow the buffer. + let target_space = self.base_capacity; + if available_space < target_space { + let new_capacity = current_len + target_space; + inner.reserve_exact(new_capacity - capacity); + } + + let len = inner.buf_len(); + let read_slice = inner.slice(len..); + self.inner.read(read_slice).await.into_inner() + }) + .await?; + if read == 0 { + self.eof = true; + } + Ok(read) + } +} + +impl SyncStream { + /// Flushes the write buffer to the underlying async stream. + /// + /// This method: + /// 1. Writes all buffered data to the underlying stream + /// 2. Calls `flush()` on the underlying stream + /// 3. Returns the total number of bytes flushed + /// + /// On error, any unwritten data remains in the buffer and can be retried. + /// + /// # Errors + /// + /// Returns an error if the underlying stream returns an error. + /// In this case, the buffer retains any data that wasn't successfully + /// written. 
+ pub async fn flush_write_buf(&mut self) -> io::Result { + let flushed = self.write_buf.flush_to(&mut self.inner).await?; + self.write_buf + .compact_to(self.base_capacity, self.max_buffer_size); + self.inner.flush().await?; + Ok(flushed) + } +} diff --git a/compio-io/tests/compat.rs b/compio-io/tests/compat.rs index c8b8520d..f7cb7750 100644 --- a/compio-io/tests/compat.rs +++ b/compio-io/tests/compat.rs @@ -60,10 +60,10 @@ fn async_compat_write() { .write(&[1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0]) .await .unwrap(); - assert_eq!(len, 10); + assert_eq!(len, 13); - stream.flush().await.unwrap(); - assert_eq!(stream.get_ref().get_ref(), &[1, 1, 4, 5, 1, 4, 1, 9, 1, 9]); + let err = stream.flush().await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::WriteZero); }) } @@ -78,6 +78,6 @@ fn async_compat_flush_fail() { .unwrap(); assert_eq!(len, 13); let err = stream.flush().await.unwrap_err(); - assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); + assert_eq!(err.kind(), std::io::ErrorKind::WriteZero); }) } diff --git a/compio-ws/Cargo.toml b/compio-ws/Cargo.toml index a268b49f..3a01061f 100644 --- a/compio-ws/Cargo.toml +++ b/compio-ws/Cargo.toml @@ -13,7 +13,7 @@ repository = { workspace = true } rustls = { workspace = true, optional = true, default-features = false } rustls-platform-verifier = { version = "0.6.0", optional = true } tungstenite = "0.28.0" -compio-io = { workspace = true } +compio-io = { workspace = true, features = ["compat"] } compio-net = { workspace = true, optional = true } compio-tls = { workspace = true, optional = true, default-features = false, features = [ "rustls", diff --git a/compio-ws/src/growable_sync_stream.rs b/compio-ws/src/growable_sync_stream.rs deleted file mode 100644 index ec7dcfa8..00000000 --- a/compio-ws/src/growable_sync_stream.rs +++ /dev/null @@ -1,359 +0,0 @@ -use std::io::{self, Read, Write}; - -use compio_buf::{BufResult, IntoInner, IoBuf}; -use compio_io::{AsyncRead, AsyncWrite}; - -/// A growable buffered stream adapter that bridges async I/O with sync traits. -/// -/// This is similar to `compio_io::compat::SyncStream` but with dynamically -/// growing buffers that can expand beyond the initial capacity up to a -/// configurable maximum. -/// -/// # Buffer Growth Strategy -/// -/// - **Read buffer**: Grows as needed to accommodate incoming data, up to -/// `max_buffer_size` -/// - **Write buffer**: Grows as needed for outgoing data, up to -/// `max_buffer_size` -/// - Both buffers shrink back to `base_capacity` when fully consumed and -/// capacity exceeds 4x base -/// -/// # Usage Pattern -/// -/// The sync `Read` and `Write` implementations will return `WouldBlock` errors -/// when buffers need servicing via the async methods: -/// -/// - Call `fill_read_buf()` when `Read::read()` returns `WouldBlock` -/// - Call `flush_write_buf()` when `Write::write()` returns `WouldBlock` -/// -/// # Note on flush() -/// -/// The `Write::flush()` method intentionally returns `Ok(())` without checking -/// if there's buffered data. This is for compatibility with libraries like -/// tungstenite that call `flush()` after every write. Actual flushing happens -/// via the async `flush_write_buf()` method. 
-#[derive(Debug)] -pub struct GrowableSyncStream { - inner: S, - read_buf: Vec, - read_pos: usize, - write_buf: Vec, - eof: bool, - base_capacity: usize, - max_buffer_size: usize, -} - -impl GrowableSyncStream { - const DEFAULT_BASE_CAPACITY: usize = 8 * 1024; - // 8KB base - const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024; - - // 64MB max - - /// Creates a new `GrowableSyncStream` with default buffer sizes. - /// - /// - Base capacity: 8KB - /// - Max buffer size: 64MB - pub fn new(stream: S) -> Self { - Self::with_capacity(Self::DEFAULT_BASE_CAPACITY, stream) - } - - /// Creates a new `GrowableSyncStream` with a custom base capacity. - /// - /// The maximum buffer size defaults to 64MB. - pub fn with_capacity(base_capacity: usize, stream: S) -> Self { - Self { - inner: stream, - read_buf: Vec::with_capacity(base_capacity), - read_pos: 0, - write_buf: Vec::with_capacity(base_capacity), - eof: false, - base_capacity, - max_buffer_size: Self::DEFAULT_MAX_BUFFER, - } - } - - /// Creates a new `GrowableSyncStream` with custom base capacity and maximum - /// buffer size. - pub fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self { - Self { - inner: stream, - read_buf: Vec::with_capacity(base_capacity), - read_pos: 0, - write_buf: Vec::with_capacity(base_capacity), - eof: false, - base_capacity, - max_buffer_size, - } - } - - /// Returns a reference to the underlying stream. - pub fn get_ref(&self) -> &S { - &self.inner - } - - /// Returns a mutable reference to the underlying stream. - pub fn get_mut(&mut self) -> &mut S { - &mut self.inner - } - - /// Consumes the `GrowableSyncStream`, returning the underlying stream. - pub fn into_inner(self) -> S { - self.inner - } - - /// Returns `true` if the stream has reached EOF. - pub fn is_eof(&self) -> bool { - self.eof - } - - /// Returns the available bytes in the read buffer. - fn available_read(&self) -> &[u8] { - &self.read_buf[self.read_pos..] - } - - /// Marks `amt` bytes as consumed from the read buffer. - /// - /// Resets the buffer when all data is consumed and shrinks capacity - /// if it has grown significantly beyond the base capacity. - fn consume_read(&mut self, amt: usize) { - self.read_pos += amt; - - // Shrink oversized buffers back to base capacity - if self.read_pos >= self.read_buf.len() { - self.read_pos = 0; - - if self.read_buf.capacity() > self.base_capacity * 4 { - self.read_buf = Vec::with_capacity(self.base_capacity); - } else { - self.read_buf.clear(); - } - } - } -} - -impl Read for GrowableSyncStream { - /// Reads data from the internal buffer. - /// - /// Returns `WouldBlock` if the buffer is empty and not at EOF, - /// indicating that `fill_read_buf()` should be called. - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let available = self.available_read(); - - if available.is_empty() && !self.eof { - return Err(io::Error::new( - io::ErrorKind::WouldBlock, - "need to fill read buffer", - )); - } - - let to_read = available.len().min(buf.len()); - buf[..to_read].copy_from_slice(&available[..to_read]); - self.consume_read(to_read); - - Ok(to_read) - } -} - -impl Write for GrowableSyncStream { - /// Writes data to the internal buffer. - /// - /// Returns `WouldBlock` if the buffer needs flushing or has reached max - /// capacity. In the latter case, it may write partial data before - /// returning `WouldBlock`. 
- fn write(&mut self, buf: &[u8]) -> io::Result { - // Check if we should flush first - if self.write_buf.len() > self.base_capacity * 2 / 3 && !self.write_buf.is_empty() { - return Err(io::Error::new( - io::ErrorKind::WouldBlock, - "need to flush write buffer", - )); - } - - // Check if write would exceed max buffer size - if self.write_buf.len() + buf.len() > self.max_buffer_size { - let space = self.max_buffer_size - self.write_buf.len(); - if space == 0 { - return Err(io::Error::new( - io::ErrorKind::WouldBlock, - "write buffer full, need to flush", - )); - } - self.write_buf.extend_from_slice(&buf[..space]); - return Ok(space); - } - - self.write_buf.extend_from_slice(buf); - Ok(buf.len()) - } - - /// Returns `Ok(())` without checking for buffered data. - /// - /// **Important**: This does NOT actually flush data to the underlying - /// stream. This behavior is intentional for compatibility with - /// libraries like tungstenite that call `flush()` after every write - /// operation. The actual async flush happens when `flush_write_buf()` - /// is called. - /// - /// This prevents spurious errors in sync code that expects `flush()` to - /// succeed after successfully buffering data. - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -impl GrowableSyncStream { - /// Fills the read buffer by reading from the underlying async stream. - /// - /// This method: - /// 1. Compacts the buffer if there's unconsumed data - /// 2. Ensures there's space for at least `base_capacity` more bytes - /// 3. Reads data from the underlying stream - /// 4. Returns the number of bytes read (0 indicates EOF) - /// - /// # Errors - /// - /// Returns an error if: - /// - The read buffer has reached `max_buffer_size` - /// - The underlying stream returns an error - pub async fn fill_read_buf(&mut self) -> io::Result { - if self.eof { - return Ok(0); - } - - // Compact buffer, move unconsumed data to the front - if self.read_pos > 0 && self.read_pos < self.read_buf.len() { - let buf_len = self.read_buf.len(); - let remaining = buf_len - self.read_pos; - self.read_buf.copy_within(self.read_pos..buf_len, 0); - - // SAFETY: We're setting the length to the amount of data we just moved. - // The data from 0..remaining is initialized (just moved from read_pos..buf_len) - unsafe { - self.read_buf.set_len(remaining); - } - self.read_pos = 0; - } else if self.read_pos >= self.read_buf.len() { - // All data consumed, reset buffer - self.read_pos = 0; - if self.read_buf.capacity() > self.base_capacity * 4 { - self.read_buf = Vec::with_capacity(self.base_capacity); - } else { - self.read_buf.clear(); - } - } - - let current_len = self.read_buf.len(); - - if current_len >= self.max_buffer_size { - return Err(io::Error::new( - io::ErrorKind::OutOfMemory, - format!("read buffer size limit ({}) exceeded", self.max_buffer_size), - )); - } - - let capacity = self.read_buf.capacity(); - let available_space = capacity - current_len; - - let target_space = self.base_capacity; - if available_space < target_space { - let new_capacity = current_len + target_space; - self.read_buf.reserve_exact(new_capacity - capacity); - } - - let capacity = self.read_buf.capacity(); - let len = self.read_buf.len(); - - // SAFETY: We're extending the buffer to its capacity to allow reading into - // uninitialized memory. This is safe because: - // 1. We save the original length and restore it on error - // 2. The async read operation initializes the bytes it writes to - // 3. 
We update the length based on how many bytes were actually read - unsafe { - self.read_buf.set_len(capacity); - } - - let buf = std::mem::take(&mut self.read_buf); - - let read_slice = IoBuf::slice(buf, len..); - - let BufResult(result, mut buf) = self.inner.read(read_slice).await.into_inner(); - - match result { - Ok(n) => { - if n == 0 { - self.eof = true; - unsafe { - buf.set_len(len); - } - } else { - unsafe { - buf.set_len(len + n); - } - } - self.read_buf = buf; - Ok(n) - } - Err(e) => { - unsafe { - buf.set_len(len); - } - self.read_buf = buf; - Err(e) - } - } - } -} - -impl GrowableSyncStream { - /// Flushes the write buffer to the underlying async stream. - /// - /// This method: - /// 1. Writes all buffered data to the underlying stream - /// 2. Calls `flush()` on the underlying stream - /// 3. Returns the total number of bytes flushed - /// - /// On error, any unwritten data remains in the buffer and can be retried. - /// - /// # Errors - /// - /// Returns an error if the underlying stream returns an error. - /// In this case, the buffer retains any data that wasn't successfully - /// written. - pub async fn flush_write_buf(&mut self) -> io::Result { - if self.write_buf.is_empty() { - return Ok(0); - } - - let total = self.write_buf.len(); - let mut buf = std::mem::take(&mut self.write_buf); - let mut flushed = 0; - - while flushed < total { - let write_slice = IoBuf::slice(buf, flushed..); - - let BufResult(result, returned_buf) = self.inner.write(write_slice).await.into_inner(); - buf = returned_buf; - - match result { - Ok(0) => { - self.write_buf = buf[flushed..].to_vec(); - return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0")); - } - Ok(n) => { - flushed += n; - } - Err(e) => { - self.write_buf = buf[flushed..].to_vec(); - return Err(e); - } - } - } - - self.write_buf = Vec::with_capacity(self.base_capacity); - - self.inner.flush().await?; - - Ok(flushed) - } -} diff --git a/compio-ws/src/lib.rs b/compio-ws/src/lib.rs index 03d31602..761c57a2 100644 --- a/compio-ws/src/lib.rs +++ b/compio-ws/src/lib.rs @@ -7,7 +7,6 @@ //! //! Each WebSocket stream implements message reading and writing. -pub mod growable_sync_stream; pub mod stream; #[cfg(feature = "rustls")] @@ -15,8 +14,7 @@ pub mod rustls; use std::io::ErrorKind; -use compio_io::{AsyncRead, AsyncWrite}; -use growable_sync_stream::GrowableSyncStream; +use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream}; use tungstenite::{ Error as WsError, HandshakeError, Message, WebSocket, client::IntoClientRequest, @@ -37,7 +35,7 @@ pub use crate::rustls::{ }; pub struct WebSocketStream { - inner: WebSocket>, + inner: WebSocket>, } impl WebSocketStream @@ -113,7 +111,7 @@ where self.inner.get_mut().get_mut() } - pub fn get_inner(self) -> WebSocket> { + pub fn get_inner(self) -> WebSocket> { self.inner } } @@ -171,7 +169,7 @@ where S: AsyncRead + AsyncWrite + Unpin + std::fmt::Debug, C: Callback, { - let sync_stream = GrowableSyncStream::new(stream); + let sync_stream = SyncStream::new(stream); let mut handshake_result = tungstenite::accept_hdr_with_config(sync_stream, callback, config); loop { @@ -235,7 +233,7 @@ where R: IntoClientRequest, S: AsyncRead + AsyncWrite + Unpin, { - let sync_stream = GrowableSyncStream::new(stream); + let sync_stream = SyncStream::new(stream); let mut handshake_result = tungstenite::client::client_with_config(request, sync_stream, config);
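
Usage note (not part of the patch): the doc comments on `SyncStream` describe the `WouldBlock`-driven protocol, but the actual driver loops live in `compio-ws`. Below is a minimal read-side sketch built only on the API added here (`Read::read` plus `fill_read_buf`); the helper name `read_exact_async` and its exact bounds are illustrative, not something this diff adds.

```rust
use std::io::{ErrorKind, Read};

use compio_io::{AsyncRead, compat::SyncStream};

/// Illustrative helper: drive a sync `Read` call to completion by refilling
/// the `SyncStream` read buffer whenever it reports `WouldBlock`.
async fn read_exact_async<S: AsyncRead>(
    stream: &mut SyncStream<S>,
    out: &mut [u8],
) -> std::io::Result<()> {
    let mut filled = 0;
    while filled < out.len() {
        match stream.read(&mut out[filled..]) {
            // `Ok(0)` from the sync reader means the underlying stream hit EOF.
            Ok(0) => return Err(ErrorKind::UnexpectedEof.into()),
            Ok(n) => filled += n,
            // The read buffer is empty: refill it from the async side and retry.
            Err(e) if e.kind() == ErrorKind::WouldBlock => {
                stream.fill_read_buf().await?;
            }
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
```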
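
A matching write-side sketch, again illustrative only: it relies on the documented behaviour that `Write::write` returns `WouldBlock` (or a partial write once the buffer approaches `max_buffer_size`) and that `Write::flush` is intentionally a no-op, so buffered bytes are pushed with `flush_write_buf`.

```rust
use std::io::{ErrorKind, Write};

use compio_io::{AsyncWrite, compat::SyncStream};

/// Illustrative helper: buffer `data` through the sync `Write` impl and drain
/// it to the underlying async stream whenever the buffer asks to be serviced.
async fn write_all_async<S: AsyncWrite>(
    stream: &mut SyncStream<S>,
    mut data: &[u8],
) -> std::io::Result<()> {
    while !data.is_empty() {
        match stream.write(data) {
            // Advance past whatever the buffer accepted (possibly a partial write).
            Ok(n) => data = &data[n..],
            // The write buffer wants to be drained before accepting more data.
            Err(e) if e.kind() == ErrorKind::WouldBlock => {
                stream.flush_write_buf().await?;
            }
            Err(e) => return Err(e),
        }
    }
    // `Write::flush` is a no-op by design, so push the remainder explicitly.
    stream.flush_write_buf().await?;
    Ok(())
}
```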