From 33859a44e81df748dc1c2a883d4b1495c27dbfe9 Mon Sep 17 00:00:00 2001 From: Steve Fan <29133953+stevefan1999-personal@users.noreply.github.com> Date: Sat, 25 Jun 2022 20:00:42 +0800 Subject: [PATCH] fixup! remove anyhow requirement --- Cargo.toml | 3 +- examples/client/main.rs | 4 +- examples/client_async/main.rs | 3 +- examples/server/main.rs | 6 +- examples/server_async/main.rs | 6 +- src/framer.rs | 172 ++++++++++++++++++++++++---------- 6 files changed, 134 insertions(+), 60 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1714e61..a6f1498 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ readme = "README.md" [dependencies] async-std = { version = "1.12.0", optional = true } +async-trait = { version = "0.1.56", optional = true } base64 = { version = "0.13.0", default-features = false } base64-simd = { version = "0.5.0", default-features = false, optional = true } byteorder = { version = "1.4.3", default-features = false } @@ -42,7 +43,7 @@ tokio-stream = { version = "0.1.9", features = ["net"] } default = ["std"] # default = [] std = [] -async = ["std"] +async = ["std", "async-trait"] tokio = ["dep:tokio", "async"] futures = ["dep:futures", "async"] smol = ["dep:smol", "async"] diff --git a/examples/client/main.rs b/examples/client/main.rs index ca0a159..b931723 100644 --- a/examples/client/main.rs +++ b/examples/client/main.rs @@ -13,9 +13,9 @@ use embedded_websocket::{ framer::{Framer, FramerError, ReadResult}, WebSocketClient, WebSocketCloseStatusCode, WebSocketOptions, WebSocketSendMessageType, }; -use std::net::TcpStream; +use std::{error::Error, net::TcpStream}; -fn main() -> Result<(), FramerError> { +fn main() -> Result<(), FramerError> { // open a TCP stream to localhost port 1337 let address = "127.0.0.1:1337"; println!("Connecting to: {}", address); diff --git a/examples/client_async/main.rs b/examples/client_async/main.rs index e60f310..ff8f959 100644 --- a/examples/client_async/main.rs +++ b/examples/client_async/main.rs @@ -13,6 +13,7 @@ use embedded_websocket::{ framer::{Framer, FramerError, ReadResult}, WebSocketClient, WebSocketCloseStatusCode, WebSocketOptions, WebSocketSendMessageType, }; +use std::error::Error; cfg_if::cfg_if! { if #[cfg(feature = "tokio")] { @@ -27,7 +28,7 @@ cfg_if::cfg_if! { #[cfg_attr(feature = "async-std", async_std::main)] #[cfg_attr(feature = "tokio", tokio::main)] #[cfg_attr(feature = "smol", smol_potat::main)] -async fn main() -> Result<(), FramerError> { +async fn main() -> Result<(), FramerError> { // open a TCP stream to localhost port 1337 let address = "127.0.0.1:1337"; println!("Connecting to: {}", address); diff --git a/examples/server/main.rs b/examples/server/main.rs index c885bbe..94349e2 100644 --- a/examples/server/main.rs +++ b/examples/server/main.rs @@ -28,7 +28,7 @@ type Result = std::result::Result; #[derive(Debug)] pub enum WebServerError { Io(std::io::Error), - Framer(FramerError), + Framer(FramerError), WebSocket(ws::Error), HttpError(String), Utf8Error, @@ -40,8 +40,8 @@ impl From for WebServerError { } } -impl From for WebServerError { - fn from(err: FramerError) -> WebServerError { +impl From> for WebServerError { + fn from(err: FramerError) -> WebServerError { WebServerError::Framer(err) } } diff --git a/examples/server_async/main.rs b/examples/server_async/main.rs index 6c43d8f..11806c2 100644 --- a/examples/server_async/main.rs +++ b/examples/server_async/main.rs @@ -46,7 +46,7 @@ cfg_if::cfg_if! { #[derive(Debug)] pub enum WebServerError { Io(std::io::Error), - Framer(FramerError), + Framer(FramerError), WebSocket(ws::Error), HttpError(String), Utf8Error, @@ -58,8 +58,8 @@ impl From for WebServerError { } } -impl From for WebServerError { - fn from(err: FramerError) -> WebServerError { +impl From> for WebServerError { + fn from(err: FramerError) -> WebServerError { WebServerError::Framer(err) } } diff --git a/src/framer.rs b/src/framer.rs index c92790d..c9a72c6 100644 --- a/src/framer.rs +++ b/src/framer.rs @@ -11,22 +11,91 @@ use crate::{ WebSocketType, }; use core::{cmp::min, str::Utf8Error}; -use core2::io::{Read, Write}; use rand_core::RngCore; -#[cfg(feature = "tokio")] -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +pub trait Read { + fn read(&mut self, buf: &mut [u8]) -> Result; +} -#[cfg(feature = "futures")] -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +pub trait Write { + fn write_all(&mut self, buf: &[u8]) -> Result<(), E>; +} -#[cfg(feature = "smol")] -use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +cfg_if::cfg_if! { + if #[cfg(feature = "std")] { + impl Read for std::net::TcpStream { + fn read(&mut self, buf: &mut [u8]) -> Result { + std::io::Read::read(self, buf) + } + } -#[cfg(feature = "async-std")] -use async_std::io::{ - Read as AsyncRead, ReadExt as AsyncReadExt, Write as AsyncWrite, WriteExt as AsyncWriteExt, -}; + impl Write for std::net::TcpStream { + fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + std::io::Write::write_all(self, buf) + } + } + } +} + +cfg_if::cfg_if! { + if #[cfg(feature = "async")] { + #[async_trait::async_trait] + pub trait AsyncRead { + async fn read(&mut self, buf: &mut [u8]) -> Result; + } + + #[async_trait::async_trait] + pub trait AsyncWrite { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), E>; + } + } +} + +cfg_if::cfg_if! { + if #[cfg(feature = "tokio")] { + #[async_trait::async_trait] + impl AsyncRead for tokio::net::TcpStream { + async fn read(&mut self, buf: &mut [u8]) -> Result { + tokio::io::AsyncReadExt::read(self, buf).await + } + } + + #[async_trait::async_trait] + impl AsyncWrite for tokio::net::TcpStream { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + tokio::io::AsyncWriteExt::write_all(self, buf).await + } + } + } else if #[cfg(feature = "smol")] { + #[async_trait::async_trait] + impl AsyncRead for smol::net::TcpStream { + async fn read(&mut self, buf: &mut [u8]) -> Result { + smol::io::AsyncReadExt::read(self, buf).await + } + } + + #[async_trait::async_trait] + impl AsyncWrite for smol::net::TcpStream { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + smol::io::AsyncWriteExt::write_all(self, buf).await + } + } + } else if #[cfg(feature = "async-std")] { + #[async_trait::async_trait] + impl AsyncRead for async_std::net::TcpStream { + async fn read(&mut self, buf: &mut [u8]) -> Result { + async_std::io::ReadExt::read(self, buf).await + } + } + + #[async_trait::async_trait] + impl AsyncWrite for async_std::net::TcpStream { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + async_std::io::WriteExt::write_all(self, buf).await + } + } + } +} pub enum ReadResult<'a> { Binary(&'a [u8]), @@ -36,8 +105,8 @@ pub enum ReadResult<'a> { } #[derive(Debug)] -pub enum FramerError { - Io(core2::io::Error), +pub enum FramerError { + Io(E), FrameTooLarge(usize), Utf8(Utf8Error), HttpHeader(httparse::Error), @@ -61,11 +130,11 @@ impl<'a, TRng> Framer<'a, TRng, crate::Client> where TRng: RngCore, { - pub fn connect( + pub fn connect( &mut self, - stream: &mut (impl Read + Write), + stream: &mut (impl Read + Write), websocket_options: &WebSocketOptions, - ) -> Result, FramerError> { + ) -> Result, FramerError> { let (len, web_socket_key) = self .websocket .client_connect(websocket_options, &mut self.write_buf) @@ -108,11 +177,11 @@ impl<'a, TRng> Framer<'a, TRng, crate::Server> where TRng: RngCore, { - pub fn accept( + pub fn accept( &mut self, - stream: &mut impl Write, + stream: &mut impl Write, websocket_context: &WebSocketContext, - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.accept_len(&websocket_context)?; stream @@ -121,7 +190,10 @@ where Ok(()) } - fn accept_len(&mut self, websocket_context: &&WebSocketContext) -> Result { + fn accept_len( + &mut self, + websocket_context: &&WebSocketContext, + ) -> Result> { self.websocket .server_accept( &websocket_context.sec_websocket_key, @@ -159,23 +231,23 @@ where self.websocket.state } - fn close_len( + fn close_len( &mut self, close_status: WebSocketCloseStatusCode, status_description: Option<&str>, - ) -> Result { + ) -> Result> { self.websocket .close(close_status, status_description, self.write_buf) .map_err(FramerError::WebSocket) } // calling close on a websocket that has already been closed by the other party has no effect - pub fn close( + pub fn close( &mut self, - stream: &mut impl Write, + stream: &mut impl Write, close_status: WebSocketCloseStatusCode, status_description: Option<&str>, - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.close_len(close_status, status_description)?; stream .write_all(&self.write_buf[..len]) @@ -183,24 +255,24 @@ where Ok(()) } - fn write_len( + fn write_len( &mut self, message_type: WebSocketSendMessageType, end_of_message: bool, frame_buf: &[u8], - ) -> Result { + ) -> Result> { self.websocket .write(message_type, end_of_message, frame_buf, self.write_buf) .map_err(FramerError::WebSocket) } - pub fn write( + pub fn write( &mut self, - stream: &mut impl Write, + stream: &mut impl Write, message_type: WebSocketSendMessageType, end_of_message: bool, frame_buf: &[u8], - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.write_len(message_type, end_of_message, frame_buf)?; stream .write_all(&self.write_buf[..len]) @@ -211,11 +283,11 @@ where // frame_buf should be large enough to hold an entire websocket text frame // this function will block until it has recieved a full websocket frame. // It will wait until the last fragmented frame has arrived. - pub fn read<'b>( + pub fn read<'b, E>( &mut self, - stream: &mut (impl Read + Write), + stream: &mut (impl Read + Write), frame_buf: &'b mut [u8], - ) -> Result, FramerError> { + ) -> Result, FramerError> { loop { if *self.read_cursor == 0 || *self.read_cursor == self.read_len { self.read_len = stream.read(self.read_buf).map_err(FramerError::Io)?; @@ -291,13 +363,13 @@ where } } - fn send_back( + fn send_back( &mut self, - stream: &mut impl Write, + stream: &mut impl Write, frame_buf: &'_ mut [u8], len_to: usize, send_message_type: WebSocketSendMessageType, - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.send_back_data(&frame_buf, len_to, send_message_type)?; stream .write_all(&self.write_buf[..len]) @@ -305,12 +377,12 @@ where Ok(()) } - fn send_back_data( + fn send_back_data( &mut self, frame_buf: &&mut [u8], len_to: usize, send_message_type: WebSocketSendMessageType, - ) -> Result { + ) -> Result> { let payload_len = min(self.write_buf.len(), len_to); let from = &frame_buf[self.frame_cursor..self.frame_cursor + payload_len]; self.websocket @@ -324,11 +396,11 @@ impl<'a, TRng> Framer<'a, TRng, crate::Client> where TRng: RngCore, { - pub async fn connect_async( + pub async fn connect_async + AsyncWrite + Unpin, E>( &mut self, stream: &mut S, websocket_options: &'a WebSocketOptions<'a>, - ) -> Result, FramerError> { + ) -> Result, FramerError> { let (len, web_socket_key) = self .websocket .client_connect(websocket_options, &mut self.write_buf) @@ -374,11 +446,11 @@ impl<'a, TRng> Framer<'a, TRng, crate::Server> where TRng: RngCore, { - pub async fn accept_async( + pub async fn accept_async + Unpin, E>( &mut self, stream: &mut W, websocket_context: &WebSocketContext, - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.accept_len(&websocket_context)?; stream @@ -396,12 +468,12 @@ where TWebSocketType: WebSocketType, { // calling close on a websocket that has already been closed by the other party has no effect - pub async fn close_async( + pub async fn close_async + Unpin, E>( &mut self, stream: &mut W, close_status: WebSocketCloseStatusCode, status_description: Option<&str>, - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.close_len(close_status, status_description)?; stream .write_all(&self.write_buf[..len]) @@ -410,13 +482,13 @@ where Ok(()) } - pub async fn write_async( + pub async fn write_async + Unpin, E>( &mut self, stream: &mut W, message_type: WebSocketSendMessageType, end_of_message: bool, frame_buf: &[u8], - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.write_len(message_type, end_of_message, frame_buf)?; stream @@ -426,11 +498,11 @@ where Ok(()) } - pub async fn read_async<'b, S: AsyncWrite + AsyncRead + Unpin>( + pub async fn read_async<'b, S: AsyncWrite + AsyncRead + Unpin, E>( &mut self, stream: &mut S, frame_buf: &'b mut [u8], - ) -> Result, FramerError> { + ) -> Result, FramerError> { loop { if *self.read_cursor == 0 || *self.read_cursor == self.read_len { self.read_len = stream.read(self.read_buf).await.map_err(FramerError::Io)?; @@ -508,13 +580,13 @@ where } } - async fn send_back_async( + async fn send_back_async + Unpin, E>( &mut self, stream: &mut W, frame_buf: &'_ mut [u8], len_to: usize, send_message_type: WebSocketSendMessageType, - ) -> Result<(), FramerError> { + ) -> Result<(), FramerError> { let len = self.send_back_data(&frame_buf, len_to, send_message_type)?; stream .write_all(&self.write_buf[..len])