diff --git a/Cargo.toml b/Cargo.toml index 69e1c1c1..55cafc20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,12 +30,12 @@ compio-driver = { path = "./compio-driver", version = "0.9.0", default-features compio-runtime = { path = "./compio-runtime", version = "0.9.0" } compio-macros = { path = "./compio-macros", version = "0.1.2" } compio-fs = { path = "./compio-fs", version = "0.9.0" } -compio-io = { path = "./compio-io", version = "0.8.0" } +compio-io = { path = "./compio-io", version = "0.8.2" } compio-net = { path = "./compio-net", version = "0.9.0" } compio-signal = { path = "./compio-signal", version = "0.7.0" } compio-dispatcher = { path = "./compio-dispatcher", version = "0.8.0" } compio-log = { path = "./compio-log", version = "0.1.0" } -compio-tls = { path = "./compio-tls", version = "0.7.0", default-features = false } +compio-tls = { path = "./compio-tls", version = "0.7.1", default-features = false } compio-process = { path = "./compio-process", version = "0.6.0" } compio-quic = { path = "./compio-quic", version = "0.5.0", default-features = false } @@ -46,6 +46,7 @@ criterion = "0.7.0" crossbeam-queue = "0.3.8" flume = { version = "0.11.0", default-features = false } futures-channel = "0.3.29" +futures-rustls = { version = "0.26.0", default-features = false } futures-util = "0.3.29" libc = "0.2.164" nix = "0.30.1" diff --git a/compio-io/Cargo.toml b/compio-io/Cargo.toml index 65631312..398006f6 100644 --- a/compio-io/Cargo.toml +++ b/compio-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "compio-io" -version = "0.8.1" +version = "0.8.2" description = "IO traits for completion based async IO" categories = ["asynchronous"] keywords = ["async", "io"] @@ -15,7 +15,6 @@ compio-buf = { workspace = true, features = ["arrayvec", "bytes"] } futures-util = { workspace = true, features = ["sink"] } paste = { workspace = true } thiserror = { workspace = true, optional = true } -pin-project-lite = { workspace = true, optional = true } serde = { version = "1.0.219", optional = true } serde_json = { version = "1.0.140", optional = true } @@ -29,7 +28,7 @@ futures-executor = "0.3.30" [features] default = [] -compat = ["dep:pin-project-lite", "futures-util/io"] +compat = ["futures-util/io"] # Codecs # Serde json codec diff --git a/compio-io/src/compat.rs b/compio-io/src/compat.rs index 84310d32..d87b7768 100644 --- a/compio-io/src/compat.rs +++ b/compio-io/src/compat.rs @@ -1,6 +1,7 @@ //! Compat wrappers for interop with other crates. use std::{ + fmt::Debug, io::{self, BufRead, Read, Write}, mem::MaybeUninit, pin::Pin, @@ -8,7 +9,6 @@ use std::{ }; use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; -use pin_project_lite::pin_project; use crate::{PinBoxFuture, buffer::Buffer, util::DEFAULT_BUF_SIZE}; @@ -176,15 +176,14 @@ impl SyncStream { } } -pin_project! { - /// A stream wrapper for [`futures_util::io`] traits. - pub struct AsyncStream { - #[pin] - inner: SyncStream, - read_future: Option>>, - write_future: Option>>, - shutdown_future: Option>>, - } +/// A stream wrapper for [`futures_util::io`] traits. +pub struct AsyncStream { + // The futures keep the reference to the inner stream, so we need to pin + // the inner stream to make sure the reference is valid. + inner: Pin>>, + read_future: Option>>, + write_future: Option>>, + shutdown_future: Option>>, } impl AsyncStream { @@ -200,7 +199,7 @@ impl AsyncStream { fn new_impl(inner: SyncStream) -> Self { Self { - inner, + inner: Box::pin(inner), read_future: None, write_future: None, shutdown_future: None, @@ -253,20 +252,18 @@ macro_rules! poll_future_would_block { impl futures_util::AsyncRead for AsyncStream { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let this = self.project(); // Safety: // - The futures won't live longer than the stream. - // - `self` is pinned. - // - The inner stream won't be moved. + // - The inner stream is pinned. let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; poll_future_would_block!( - this.read_future, + self.read_future, cx, inner.fill_read_buf(), io::Read::read(inner, buf) @@ -279,16 +276,14 @@ impl AsyncStream { /// /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. pub fn poll_read_uninit( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [MaybeUninit], ) -> Poll> { - let this = self.project(); - let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; poll_future_would_block!( - this.read_future, + self.read_future, cx, inner.fill_read_buf(), inner.read_buf_uninit(buf) @@ -297,13 +292,11 @@ impl AsyncStream { } impl futures_util::AsyncBufRead for AsyncStream { - fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; poll_future_would_block!( - this.read_future, + self.read_future, cx, inner.fill_read_buf(), // Safety: anyway the slice won't be used after free. @@ -311,65 +304,63 @@ impl futures_util::AsyncBufRead for AsyncStream, amt: usize) { - let this = self.project(); - - let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; - inner.consume(amt) + fn consume(mut self: Pin<&mut Self>, amt: usize) { + unsafe { self.inner.as_mut().get_unchecked_mut().consume(amt) } } } impl futures_util::AsyncWrite for AsyncStream { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let this = self.project(); - - if this.shutdown_future.is_some() { - debug_assert!(this.write_future.is_none()); + if self.shutdown_future.is_some() { + debug_assert!(self.write_future.is_none()); return Poll::Pending; } let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; + unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; poll_future_would_block!( - this.write_future, + self.write_future, cx, inner.flush_write_buf(), io::Write::write(inner, buf) ) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - if this.shutdown_future.is_some() { - debug_assert!(this.write_future.is_none()); + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.shutdown_future.is_some() { + debug_assert!(self.write_future.is_none()); return Poll::Pending; } let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; - let res = poll_future!(this.write_future, cx, inner.flush_write_buf()); + unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; + let res = poll_future!(self.write_future, cx, inner.flush_write_buf()); Poll::Ready(res.map(|_| ())) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Avoid shutdown on flush because the inner buffer might be passed to the // driver. - if this.write_future.is_some() { - debug_assert!(this.shutdown_future.is_none()); + if self.write_future.is_some() { + debug_assert!(self.shutdown_future.is_none()); return Poll::Pending; } let inner: &'static mut SyncStream = - unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; - let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown()); + unsafe { &mut *(self.inner.as_mut().get_unchecked_mut() as *mut _) }; + let res = poll_future!(self.shutdown_future, cx, inner.get_mut().shutdown()); Poll::Ready(res) } } + +impl Debug for AsyncStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncStream") + .field("inner", &self.inner) + .finish_non_exhaustive() + } +} diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml index 31a5755f..74be9bb5 100644 --- a/compio-tls/Cargo.toml +++ b/compio-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "compio-tls" -version = "0.7.0" +version = "0.7.1" description = "TLS adaptor with compio" categories = ["asynchronous", "network-programming"] keywords = ["async", "net", "tls"] @@ -25,6 +25,12 @@ rustls = { workspace = true, default-features = false, optional = true, features "tls12", ] } +futures-rustls = { workspace = true, default-features = false, optional = true, features = [ + "logging", + "tls12", +] } +futures-util = { workspace = true, optional = true } + [dev-dependencies] compio-net = { workspace = true } compio-runtime = { workspace = true } @@ -33,14 +39,18 @@ compio-macros = { workspace = true } rustls = { workspace = true, default-features = false, features = ["ring"] } rustls-native-certs = { workspace = true } +futures-rustls = { workspace = true, default-features = false, features = [ + "ring", +] } + [features] default = ["native-tls"] all = ["native-tls", "rustls"] -rustls = ["dep:rustls"] +rustls = ["dep:rustls", "dep:futures-rustls", "dep:futures-util"] -ring = ["rustls", "rustls/ring"] -aws-lc-rs = ["rustls", "rustls/aws-lc-rs"] -aws-lc-rs-fips = ["aws-lc-rs", "rustls/fips"] +ring = ["rustls", "rustls/ring", "futures-rustls/ring"] +aws-lc-rs = ["rustls", "rustls/aws-lc-rs", "futures-rustls/aws-lc-rs"] +aws-lc-rs-fips = ["aws-lc-rs", "rustls/fips", "futures-rustls/fips"] read_buf = ["compio-buf/read_buf", "compio-io/read_buf", "rustls?/read_buf"] nightly = ["read_buf"] diff --git a/compio-tls/src/adapter/mod.rs b/compio-tls/src/adapter.rs similarity index 73% rename from compio-tls/src/adapter/mod.rs rename to compio-tls/src/adapter.rs index 25922202..cee4b5fa 100644 --- a/compio-tls/src/adapter/mod.rs +++ b/compio-tls/src/adapter.rs @@ -1,18 +1,29 @@ -use std::io; +use std::{fmt::Debug, io}; -use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream}; +use compio_io::{ + AsyncRead, AsyncWrite, + compat::{AsyncStream, SyncStream}, +}; use crate::TlsStream; -#[cfg(feature = "rustls")] -mod rtls; - -#[derive(Debug, Clone)] +#[derive(Clone)] enum TlsConnectorInner { #[cfg(feature = "native-tls")] NativeTls(native_tls::TlsConnector), #[cfg(feature = "rustls")] - Rustls(rtls::TlsConnector), + Rustls(futures_rustls::TlsConnector), +} + +impl Debug for TlsConnectorInner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + #[cfg(feature = "native-tls")] + Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(), + #[cfg(feature = "rustls")] + Self::Rustls(_) => f.debug_tuple("Rustls").finish(), + } + } } /// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`], @@ -30,7 +41,7 @@ impl From for TlsConnector { #[cfg(feature = "rustls")] impl From> for TlsConnector { fn from(value: std::sync::Arc) -> Self { - Self(TlsConnectorInner::Rustls(rtls::TlsConnector(value))) + Self(TlsConnectorInner::Rustls(value.into())) } } @@ -47,7 +58,7 @@ impl TlsConnector { /// example, a TCP connection to a remote server. That stream is then /// provided here to perform the client half of a connection to a /// TLS-powered server. - pub async fn connect( + pub async fn connect( &self, domain: &str, stream: S, @@ -58,7 +69,15 @@ impl TlsConnector { handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await } #[cfg(feature = "rustls")] - TlsConnectorInner::Rustls(c) => handshake_rustls(c.connect(domain, stream)).await, + TlsConnectorInner::Rustls(c) => { + let client = c + .connect( + domain.to_string().try_into().map_err(io::Error::other)?, + AsyncStream::new(stream), + ) + .await?; + Ok(TlsStream::from(client)) + } } } } @@ -68,7 +87,7 @@ enum TlsAcceptorInner { #[cfg(feature = "native-tls")] NativeTls(native_tls::TlsAcceptor), #[cfg(feature = "rustls")] - Rustls(rtls::TlsAcceptor), + Rustls(futures_rustls::TlsAcceptor), } /// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`], @@ -86,7 +105,7 @@ impl From for TlsAcceptor { #[cfg(feature = "rustls")] impl From> for TlsAcceptor { fn from(value: std::sync::Arc) -> Self { - Self(TlsAcceptorInner::Rustls(rtls::TlsAcceptor(value))) + Self(TlsAcceptorInner::Rustls(value.into())) } } @@ -101,14 +120,20 @@ impl TlsAcceptor { /// This is typically used after a new socket has been accepted from a /// `TcpListener`. That socket is then passed to this function to perform /// the server half of accepting a client connection. - pub async fn accept(&self, stream: S) -> io::Result> { + pub async fn accept( + &self, + stream: S, + ) -> io::Result> { match &self.0 { #[cfg(feature = "native-tls")] TlsAcceptorInner::NativeTls(c) => { handshake_native_tls(c.accept(SyncStream::new(stream))).await } #[cfg(feature = "rustls")] - TlsAcceptorInner::Rustls(c) => handshake_rustls(c.accept(stream)).await, + TlsAcceptorInner::Rustls(c) => { + let server = c.accept(AsyncStream::new(stream)).await?; + Ok(TlsStream::from(server)) + } } } } @@ -140,32 +165,3 @@ async fn handshake_native_tls( } } } - -#[cfg(feature = "rustls")] -async fn handshake_rustls( - mut res: Result, rtls::HandshakeError>, -) -> io::Result> -where - C: std::ops::DerefMut>, -{ - use rtls::HandshakeError; - - loop { - match res { - Ok(mut s) => { - s.flush().await?; - return Ok(s); - } - Err(e) => match e { - HandshakeError::Rustls(e) => return Err(io::Error::other(e)), - HandshakeError::System(e) => return Err(e), - HandshakeError::WouldBlock(mut mid_stream) => { - if mid_stream.get_mut().flush_write_buf().await? == 0 { - mid_stream.get_mut().fill_read_buf().await?; - } - res = mid_stream.handshake::(); - } - }, - } - } -} diff --git a/compio-tls/src/adapter/rtls.rs b/compio-tls/src/adapter/rtls.rs deleted file mode 100644 index ab5fd453..00000000 --- a/compio-tls/src/adapter/rtls.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::{io, ops::DerefMut, sync::Arc}; - -use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream}; -use rustls::{ - ClientConfig, ClientConnection, ConnectionCommon, Error, ServerConfig, ServerConnection, - pki_types::ServerName, -}; - -use crate::TlsStream; - -pub enum HandshakeError { - Rustls(Error), - System(io::Error), - WouldBlock(MidStream), -} - -pub struct MidStream { - stream: SyncStream, - conn: C, - result_fn: fn(SyncStream, C) -> TlsStream, -} - -impl MidStream { - pub fn new( - stream: SyncStream, - conn: C, - result_fn: fn(SyncStream, C) -> TlsStream, - ) -> Self { - Self { - stream, - conn, - result_fn, - } - } - - pub fn get_mut(&mut self) -> &mut SyncStream { - &mut self.stream - } - - pub fn handshake(mut self) -> Result, HandshakeError> - where - C: DerefMut>, - S: AsyncRead + AsyncWrite, - { - loop { - let mut write_would_block = false; - let mut read_would_block = false; - - while self.conn.wants_write() { - match self.conn.write_tls(&mut self.stream) { - Ok(_) => { - write_would_block = true; - } - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - write_would_block = true; - break; - } - Err(e) => return Err(HandshakeError::System(e)), - } - } - - while !self.stream.is_eof() && self.conn.wants_read() { - match self.conn.read_tls(&mut self.stream) { - Ok(_) => { - self.conn - .process_new_packets() - .map_err(HandshakeError::Rustls)?; - } - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - read_would_block = true; - break; - } - Err(e) => return Err(HandshakeError::System(e)), - } - } - - return match (self.stream.is_eof(), self.conn.is_handshaking()) { - (true, true) => { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); - Err(HandshakeError::System(err)) - } - (_, false) => Ok((self.result_fn)(self.stream, self.conn)), - (_, true) if write_would_block || read_would_block => { - Err(HandshakeError::WouldBlock(self)) - } - _ => continue, - }; - } - } -} - -#[derive(Debug, Clone)] -pub struct TlsConnector(pub Arc); - -impl TlsConnector { - #[allow(clippy::result_large_err)] - pub fn connect( - &self, - domain: &str, - stream: S, - ) -> Result, HandshakeError> { - let conn = ClientConnection::new( - self.0.clone(), - ServerName::try_from(domain) - .map_err(|e| HandshakeError::System(io::Error::other(e)))? - .to_owned(), - ) - .map_err(HandshakeError::Rustls)?; - - MidStream::new( - SyncStream::new(stream), - conn, - TlsStream::::new_rustls_client, - ) - .handshake() - } -} - -#[derive(Debug, Clone)] -pub struct TlsAcceptor(pub Arc); - -impl TlsAcceptor { - #[allow(clippy::result_large_err)] - pub fn accept( - &self, - stream: S, - ) -> Result, HandshakeError> { - let conn = ServerConnection::new(self.0.clone()).map_err(HandshakeError::Rustls)?; - - MidStream::new( - SyncStream::new(stream), - conn, - TlsStream::::new_rustls_server, - ) - .handshake() - } -} diff --git a/compio-tls/src/stream.rs b/compio-tls/src/stream.rs new file mode 100644 index 00000000..5760d605 --- /dev/null +++ b/compio-tls/src/stream.rs @@ -0,0 +1,173 @@ +use std::{borrow::Cow, io, mem::MaybeUninit}; + +use compio_buf::{BufResult, IoBuf, IoBufMut}; +use compio_io::{ + AsyncRead, AsyncWrite, + compat::{AsyncStream, SyncStream}, +}; + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum TlsStreamInner { + #[cfg(feature = "native-tls")] + NativeTls(native_tls::TlsStream>), + #[cfg(feature = "rustls")] + Rustls(futures_rustls::TlsStream>), +} + +impl TlsStreamInner { + pub fn negotiated_alpn(&self) -> Option> { + match self { + #[cfg(feature = "native-tls")] + Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from), + #[cfg(feature = "rustls")] + Self::Rustls(s) => s.get_ref().1.alpn_protocol().map(Cow::from), + } + } +} + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +/// +/// A `TlsStream` represents a handshake that has been completed successfully +/// and both the server and the client are ready for receiving and sending +/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written +/// to a `TlsStream` are encrypted when passing through to `S`. +#[derive(Debug)] +pub struct TlsStream(TlsStreamInner); + +impl TlsStream { + /// Returns the negotiated ALPN protocol. + pub fn negotiated_alpn(&self) -> Option> { + self.0.negotiated_alpn() + } +} + +#[cfg(feature = "native-tls")] +#[doc(hidden)] +impl From>> for TlsStream { + fn from(value: native_tls::TlsStream>) -> Self { + Self(TlsStreamInner::NativeTls(value)) + } +} + +#[cfg(feature = "rustls")] +#[doc(hidden)] +impl From>> for TlsStream { + fn from(value: futures_rustls::client::TlsStream>) -> Self { + Self(TlsStreamInner::Rustls(futures_rustls::TlsStream::Client( + value, + ))) + } +} + +#[cfg(feature = "rustls")] +#[doc(hidden)] +impl From>> for TlsStream { + fn from(value: futures_rustls::server::TlsStream>) -> Self { + Self(TlsStreamInner::Rustls(futures_rustls::TlsStream::Server( + value, + ))) + } +} + +impl AsyncRead for TlsStream { + async fn read(&mut self, mut buf: B) -> BufResult { + let slice = buf.as_mut_slice(); + slice.fill(MaybeUninit::new(0)); + // SAFETY: The memory has been initialized + let slice = + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) }; + match &mut self.0 { + #[cfg(feature = "native-tls")] + TlsStreamInner::NativeTls(s) => loop { + match io::Read::read(s, slice) { + Ok(res) => { + unsafe { buf.set_buf_init(res) }; + return BufResult(Ok(res), buf); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match s.get_mut().fill_read_buf().await { + Ok(_) => continue, + Err(e) => return BufResult(Err(e), buf), + } + } + res => return BufResult(res, buf), + } + }, + #[cfg(feature = "rustls")] + TlsStreamInner::Rustls(s) => { + let res = futures_util::AsyncReadExt::read(s, slice).await; + let res = match res { + Ok(len) => { + unsafe { buf.set_buf_init(len) }; + Ok(len) + } + // TLS streams may return UnexpectedEof when the connection is closed. + // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0), + _ => res, + }; + BufResult(res, buf) + } + } + } +} + +#[cfg(feature = "native-tls")] +async fn flush_impl(s: &mut native_tls::TlsStream>) -> io::Result<()> { + loop { + match io::Write::flush(s) { + Ok(()) => break, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + s.get_mut().flush_write_buf().await?; + } + Err(e) => return Err(e), + } + } + s.get_mut().flush_write_buf().await?; + Ok(()) +} + +impl AsyncWrite for TlsStream { + async fn write(&mut self, buf: T) -> BufResult { + let slice = buf.as_slice(); + match &mut self.0 { + #[cfg(feature = "native-tls")] + TlsStreamInner::NativeTls(s) => loop { + let res = io::Write::write(s, slice); + match res { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => match flush_impl(s).await { + Ok(_) => continue, + Err(e) => return BufResult(Err(e), buf), + }, + _ => return BufResult(res, buf), + } + }, + #[cfg(feature = "rustls")] + TlsStreamInner::Rustls(s) => { + let res = futures_util::AsyncWriteExt::write(s, slice).await; + BufResult(res, buf) + } + } + } + + async fn flush(&mut self) -> io::Result<()> { + match &mut self.0 { + #[cfg(feature = "native-tls")] + TlsStreamInner::NativeTls(s) => flush_impl(s).await, + #[cfg(feature = "rustls")] + TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::flush(s).await, + } + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.flush().await?; + match &mut self.0 { + #[cfg(feature = "native-tls")] + TlsStreamInner::NativeTls(s) => s.get_mut().get_mut().shutdown().await, + #[cfg(feature = "rustls")] + TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::close(s).await, + } + } +} diff --git a/compio-tls/src/stream/mod.rs b/compio-tls/src/stream/mod.rs deleted file mode 100644 index 67db5cbf..00000000 --- a/compio-tls/src/stream/mod.rs +++ /dev/null @@ -1,188 +0,0 @@ -use std::{borrow::Cow, io, mem::MaybeUninit}; - -use compio_buf::{BufResult, IoBuf, IoBufMut}; -use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream}; - -#[cfg(feature = "rustls")] -mod rtls; - -#[derive(Debug)] -#[allow(clippy::large_enum_variant)] -enum TlsStreamInner { - #[cfg(feature = "native-tls")] - NativeTls(native_tls::TlsStream>), - #[cfg(feature = "rustls")] - Rustls(rtls::TlsStream>), -} - -impl TlsStreamInner { - fn get_mut(&mut self) -> &mut SyncStream { - match self { - #[cfg(feature = "native-tls")] - Self::NativeTls(s) => s.get_mut(), - #[cfg(feature = "rustls")] - Self::Rustls(s) => s.get_mut(), - } - } - - pub fn negotiated_alpn(&self) -> Option> { - match self { - #[cfg(feature = "native-tls")] - Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from), - #[cfg(feature = "rustls")] - Self::Rustls(s) => s.negotiated_alpn().map(Cow::from), - } - } -} - -impl io::Read for TlsStreamInner { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - #[cfg(feature = "native-tls")] - Self::NativeTls(s) => io::Read::read(s, buf), - #[cfg(feature = "rustls")] - Self::Rustls(s) => io::Read::read(s, buf), - } - } - - #[cfg(feature = "read_buf")] - fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> { - match self { - #[cfg(feature = "native-tls")] - Self::NativeTls(s) => io::Read::read_buf(s, buf), - #[cfg(feature = "rustls")] - Self::Rustls(s) => io::Read::read_buf(s, buf), - } - } -} - -impl io::Write for TlsStreamInner { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - #[cfg(feature = "native-tls")] - Self::NativeTls(s) => io::Write::write(s, buf), - #[cfg(feature = "rustls")] - Self::Rustls(s) => io::Write::write(s, buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - #[cfg(feature = "native-tls")] - Self::NativeTls(s) => io::Write::flush(s), - #[cfg(feature = "rustls")] - Self::Rustls(s) => io::Write::flush(s), - } - } -} - -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -/// -/// A `TlsStream` represents a handshake that has been completed successfully -/// and both the server and the client are ready for receiving and sending -/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written -/// to a `TlsStream` are encrypted when passing through to `S`. -#[derive(Debug)] -pub struct TlsStream(TlsStreamInner); - -impl TlsStream { - #[cfg(feature = "rustls")] - pub(crate) fn new_rustls_client(s: SyncStream, conn: rustls::ClientConnection) -> Self { - Self(TlsStreamInner::Rustls(rtls::TlsStream::new_client(s, conn))) - } - - #[cfg(feature = "rustls")] - pub(crate) fn new_rustls_server(s: SyncStream, conn: rustls::ServerConnection) -> Self { - Self(TlsStreamInner::Rustls(rtls::TlsStream::new_server(s, conn))) - } - - /// Returns the negotiated ALPN protocol. - pub fn negotiated_alpn(&self) -> Option> { - self.0.negotiated_alpn() - } -} - -#[cfg(feature = "native-tls")] -#[doc(hidden)] -impl From>> for TlsStream { - fn from(value: native_tls::TlsStream>) -> Self { - Self(TlsStreamInner::NativeTls(value)) - } -} - -impl AsyncRead for TlsStream { - async fn read(&mut self, mut buf: B) -> BufResult { - let slice: &mut [MaybeUninit] = buf.as_mut_slice(); - - #[cfg(feature = "read_buf")] - let mut f = { - let mut borrowed_buf = io::BorrowedBuf::from(slice); - move |s: &mut _| { - let mut cursor = borrowed_buf.unfilled(); - std::io::Read::read_buf(s, cursor.reborrow())?; - Ok::(cursor.written()) - } - }; - - #[cfg(not(feature = "read_buf"))] - let mut f = { - slice.fill(MaybeUninit::new(0)); - // SAFETY: The memory has been initialized - let slice = - unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) }; - |s: &mut _| std::io::Read::read(s, slice) - }; - - loop { - match f(&mut self.0) { - Ok(res) => { - unsafe { buf.set_buf_init(res) }; - return BufResult(Ok(res), buf); - } - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - match self.0.get_mut().fill_read_buf().await { - Ok(_) => continue, - Err(e) => return BufResult(Err(e), buf), - } - } - res => return BufResult(res, buf), - } - } - } -} - -impl AsyncWrite for TlsStream { - async fn write(&mut self, buf: T) -> BufResult { - let slice = buf.as_slice(); - loop { - let res = io::Write::write(&mut self.0, slice); - match res { - Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await { - Ok(_) => continue, - Err(e) => return BufResult(Err(e), buf), - }, - _ => return BufResult(res, buf), - } - } - } - - async fn flush(&mut self) -> io::Result<()> { - loop { - match io::Write::flush(&mut self.0) { - Ok(()) => break, - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - self.0.get_mut().flush_write_buf().await?; - } - Err(e) => return Err(e), - } - } - self.0.get_mut().flush_write_buf().await?; - Ok(()) - } - - async fn shutdown(&mut self) -> io::Result<()> { - self.flush().await?; - self.0.get_mut().get_mut().shutdown().await - } -} diff --git a/compio-tls/src/stream/rtls.rs b/compio-tls/src/stream/rtls.rs deleted file mode 100644 index bf94be9d..00000000 --- a/compio-tls/src/stream/rtls.rs +++ /dev/null @@ -1,138 +0,0 @@ -use std::io; - -use rustls::{ClientConnection, Error, IoState, Reader, ServerConnection, Writer}; - -#[derive(Debug)] -enum TlsConnection { - Client(ClientConnection), - Server(ServerConnection), -} - -impl TlsConnection { - pub fn reader(&mut self) -> Reader<'_> { - match self { - Self::Client(c) => c.reader(), - Self::Server(c) => c.reader(), - } - } - - pub fn writer(&mut self) -> Writer<'_> { - match self { - Self::Client(c) => c.writer(), - Self::Server(c) => c.writer(), - } - } - - pub fn process_new_packets(&mut self) -> Result { - match self { - Self::Client(c) => c.process_new_packets(), - Self::Server(c) => c.process_new_packets(), - } - } - - pub fn read_tls(&mut self, rd: &mut dyn io::Read) -> io::Result { - match self { - Self::Client(c) => c.read_tls(rd), - Self::Server(c) => c.read_tls(rd), - } - } - - pub fn wants_read(&self) -> bool { - match self { - Self::Client(c) => c.wants_read(), - Self::Server(c) => c.wants_read(), - } - } - - pub fn write_tls(&mut self, wr: &mut dyn io::Write) -> io::Result { - match self { - Self::Client(c) => c.write_tls(wr), - Self::Server(c) => c.write_tls(wr), - } - } - - pub fn wants_write(&self) -> bool { - match self { - Self::Client(c) => c.wants_write(), - Self::Server(c) => c.wants_write(), - } - } -} - -#[derive(Debug)] -pub struct TlsStream { - inner: S, - conn: TlsConnection, -} - -impl TlsStream { - pub fn new_client(inner: S, conn: ClientConnection) -> Self { - Self { - inner, - conn: TlsConnection::Client(conn), - } - } - - pub fn new_server(inner: S, conn: ServerConnection) -> Self { - Self { - inner, - conn: TlsConnection::Server(conn), - } - } - - pub fn get_mut(&mut self) -> &mut S { - &mut self.inner - } - - pub fn negotiated_alpn(&self) -> Option<&[u8]> { - match &self.conn { - TlsConnection::Client(client) => client.alpn_protocol(), - TlsConnection::Server(server) => server.alpn_protocol(), - } - } -} - -impl TlsStream { - fn read_impl(&mut self, mut f: impl FnMut(Reader) -> io::Result) -> io::Result { - loop { - while self.conn.wants_read() { - self.conn.read_tls(&mut self.inner)?; - self.conn.process_new_packets().map_err(io::Error::other)?; - } - - match f(self.conn.reader()) { - Ok(len) => { - return Ok(len); - } - Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue, - Err(e) => return Err(e), - } - } - } -} - -impl io::Read for TlsStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.read_impl(|mut reader| reader.read(buf)) - } - - #[cfg(feature = "read_buf")] - fn read_buf(&mut self, mut buf: io::BorrowedCursor<'_>) -> io::Result<()> { - self.read_impl(|mut reader| reader.read_buf(buf.reborrow())) - } -} - -impl io::Write for TlsStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.flush()?; - self.conn.writer().write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - while self.conn.wants_write() { - self.conn.write_tls(&mut self.inner)?; - } - self.inner.flush()?; - Ok(()) - } -}