diff --git a/Cargo.toml b/Cargo.toml index 8862e03..d2ad551 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,18 +11,64 @@ categories = ["embedded", "no-std", "network-programming"] readme = "README.md" [dependencies] -sha1 = { version = "0.10.1", default-features = false } -heapless = "0.7.14" +async-trait = { version = "0.1.56", optional = true } +base64 = { version = "0.13.0", default-features = false } +base64-simd = { version = "0.5.0", default-features = false, optional = true } byteorder = { version = "1.4.3", default-features = false } +cfg-if = "1.0.0" +heapless = "0.7.14" httparse = { version = "1.7.1", default-features = false } rand_core = "0.6.3" -base64 = { version = "0.13.0", default-features = false } +sha1 = { version = "0.10.1", default-features = false } [dev-dependencies] +async-std = { version = "1.12.0", features = ["attributes"] } +async-trait = "0.1.56" +cfg-if = "1.0.0" +once_cell = "1.12.0" rand = "0.8.5" +route-recognizer = "0.3.1" +smol = "1.2.5" +smol-potat = { version = "1.1.2", features = ["auto"] } +tokio = { version = "1.19.2", features = ["macros", "net", "rt-multi-thread", "io-util"] } +tokio-stream = { version = "0.1.9", features = ["net"] } # see readme for no_std support [features] default = ["std"] # default = [] std = [] +async = ["std", "async-trait"] +example-tokio = ["async"] +example-smol = ["async"] +example-async-std = ["async"] + +[[example]] +name = "server_tokio" +path = "examples/server_async/main.rs" +required-features = ["example-tokio"] + +[[example]] +name = "server_smol" +path = "examples/server_async/main.rs" +required-features = ["example-smol"] + +[[example]] +name = "server_async_std" +path = "examples/server_async/main.rs" +required-features = ["example-async-std"] + +[[example]] +name = "client_tokio" +path = "examples/client_async/main.rs" +required-features = ["example-tokio"] + +[[example]] +name = "client_smol" +path = "examples/client_async/main.rs" +required-features = ["example-smol"] + +[[example]] +name = "client_async_std" +path = "examples/client_async/main.rs" +required-features = ["example-async-std"] diff --git a/examples/client.rs b/examples/client/main.rs similarity index 100% rename from examples/client.rs rename to examples/client/main.rs diff --git a/examples/client_async/compat.rs b/examples/client_async/compat.rs new file mode 100644 index 0000000..4fdf12d --- /dev/null +++ b/examples/client_async/compat.rs @@ -0,0 +1,110 @@ +#![allow(dead_code)] + +// This is an example implementation of compatibility extension, gladly stolen from futures_lite +// As this is far from being any useful, please do extend this on your own +pub trait CompatExt { + fn compat(self) -> Compat + where + Self: Sized; + fn compat_ref(&self) -> Compat<&Self>; + fn compat_mut(&mut self) -> Compat<&mut Self>; +} + +impl CompatExt for T { + fn compat(self) -> Compat + where + Self: Sized, + { + Compat(self) + } + + fn compat_ref(&self) -> Compat<&Self> { + Compat(self) + } + + fn compat_mut(&mut self) -> Compat<&mut Self> { + Compat(self) + } +} + +pub struct Compat(T); + +impl Compat { + pub fn get_ref(&self) -> &T { + &self.0 + } + + pub fn get_mut(&mut self) -> &mut T { + &mut self.0 + } + + pub fn into_inner(self) -> T { + self.0 + } +} + +#[cfg(feature = "example-tokio")] +pub mod tokio_compat { + use super::Compat; + use async_trait::async_trait; + use embedded_websocket::compat::{AsyncRead, AsyncWrite}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[async_trait] + impl AsyncRead for Compat { + async fn read(&mut self, buf: &mut [u8]) -> Result { + AsyncReadExt::read(self.get_mut(), buf).await + } + } + + #[async_trait] + impl AsyncWrite for Compat { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + AsyncWriteExt::write_all(self.get_mut(), buf).await + } + } +} + +#[cfg(feature = "example-smol")] +pub mod smol_compat { + use super::Compat; + use async_trait::async_trait; + use embedded_websocket::compat::{AsyncRead, AsyncWrite}; + use smol::io::{AsyncReadExt, AsyncWriteExt}; + + #[async_trait] + impl AsyncRead for Compat { + async fn read(&mut self, buf: &mut [u8]) -> Result { + AsyncReadExt::read(self.get_mut(), buf).await + } + } + + #[async_trait] + impl AsyncWrite for Compat { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + AsyncWriteExt::write_all(self.get_mut(), buf).await + } + } +} + +#[cfg(feature = "example-async-std")] +pub mod async_std_compat { + use super::Compat; + use async_std::io::{ReadExt, WriteExt}; + use async_trait::async_trait; + use embedded_websocket::compat::{AsyncRead, AsyncWrite}; + + #[async_trait] + impl AsyncRead for Compat { + async fn read(&mut self, buf: &mut [u8]) -> Result { + ReadExt::read(self.get_mut(), buf).await + } + } + + #[async_trait] + impl AsyncWrite for Compat { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + WriteExt::write_all(self.get_mut(), buf).await + } + } +} diff --git a/examples/client_async/main.rs b/examples/client_async/main.rs new file mode 100644 index 0000000..b24140f --- /dev/null +++ b/examples/client_async/main.rs @@ -0,0 +1,90 @@ +// The MIT License (MIT) +// Copyright (c) 2019 David Haig + +// Demo websocket client connecting to localhost port 1337. +// This will initiate a websocket connection to path /chat. The demo sends a simple "Hello, World!" +// message and expects an echo of the same message as a reply. +// It will then initiate a close handshake, wait for a close response from the server, +// and terminate the connection. +// Note that we are using the standard library in the demo and making use of the framer helper module +// but the websocket library remains no_std (see client_full for an example without the framer helper module) + +mod compat; + +use crate::compat::CompatExt; +use embedded_websocket::{ + framer::{Framer, FramerError, ReadResult}, + WebSocketClient, WebSocketCloseStatusCode, WebSocketOptions, WebSocketSendMessageType, +}; +use std::error::Error; + +cfg_if::cfg_if! { + if #[cfg(feature = "example-tokio")] { + use tokio::net::TcpStream; + } else if #[cfg(feature = "example-smol")] { + use smol::net::TcpStream; + } else if #[cfg(feature = "example-async-std")] { + use async_std::net::TcpStream; + } +} + +#[cfg_attr(feature = "example-async-std", async_std::main)] +#[cfg_attr(feature = "example-tokio", tokio::main)] +#[cfg_attr(feature = "example-smol", smol_potat::main)] +async fn main() -> Result<(), FramerError> { + // open a TCP stream to localhost port 1337 + let address = "127.0.0.1:1337"; + println!("Connecting to: {}", address); + let mut stream = TcpStream::connect(address).await.map_err(FramerError::Io)?; + println!("Connected."); + + let mut read_buf = [0; 4000]; + let mut read_cursor = 0; + let mut write_buf = [0; 4000]; + let mut frame_buf = [0; 4000]; + let mut websocket = WebSocketClient::new_client(rand::thread_rng()); + + // initiate a websocket opening handshake + let websocket_options = WebSocketOptions { + path: "/chat", + host: "localhost", + origin: "http://localhost:1337", + sub_protocols: None, + additional_headers: None, + }; + + let mut framer = Framer::new( + &mut read_buf, + &mut read_cursor, + &mut write_buf, + &mut websocket, + ); + let mut stream = stream.compat_mut(); + + framer + .connect_async(&mut stream, &websocket_options) + .await?; + + let message = "Hello, World!"; + framer + .write_async( + &mut stream, + WebSocketSendMessageType::Text, + true, + message.as_bytes(), + ) + .await?; + + while let ReadResult::Text(s) = framer.read_async(&mut stream, &mut frame_buf).await? { + println!("Received: {}", s); + + // close the websocket after receiving the first reply + framer + .close_async(&mut stream, WebSocketCloseStatusCode::NormalClosure, None) + .await?; + println!("Sent close handshake"); + } + + println!("Connection closed"); + Ok(()) +} diff --git a/examples/server.rs b/examples/server/main.rs similarity index 68% rename from examples/server.rs rename to examples/server/main.rs index 29f18d8..3397986 100644 --- a/examples/server.rs +++ b/examples/server/main.rs @@ -3,23 +3,24 @@ // Demo websocket server that listens on localhost port 1337. // If accessed from a browser it will return a web page that will automatically attempt to -// open a websocket connection to itself. Alternatively, the client.rs example can be used to +// open a websocket connection to itself. Alternatively, the main example can be used to // open a websocket connection directly. The server will echo all Text and Ping messages back to // the client as well as responding to any opening and closing handshakes. // Note that we are using the standard library in the demo but the websocket library remains no_std use embedded_websocket as ws; -use std::net::{TcpListener, TcpStream}; -use std::str::Utf8Error; -use std::thread; +use httparse::Request; +use once_cell::sync::Lazy; +use route_recognizer::Router; use std::{ io::{Read, Write}, - usize, + net::{TcpListener, TcpStream}, + str::Utf8Error, + thread, }; -use ws::framer::ReadResult; use ws::{ - framer::{Framer, FramerError}, - WebSocketContext, WebSocketSendMessageType, WebSocketServer, + framer::{Framer, FramerError, ReadResult}, + WebSocketSendMessageType, WebSocketServer, }; type Result = std::result::Result; @@ -29,7 +30,9 @@ pub enum WebServerError { Io(std::io::Error), Framer(FramerError), WebSocket(ws::Error), + HttpError(httparse::Error), Utf8Error, + Custom(String), } impl From for WebServerError { @@ -56,6 +59,12 @@ impl From for WebServerError { } } +impl From for WebServerError { + fn from(err: httparse::Error) -> WebServerError { + WebServerError::HttpError(err) + } +} + fn main() -> std::io::Result<()> { let addr = "127.0.0.1:1337"; let listener = TcpListener::bind(addr)?; @@ -77,13 +86,24 @@ fn main() -> std::io::Result<()> { Ok(()) } -fn handle_client(mut stream: TcpStream) -> Result<()> { - println!("Client connected {}", stream.peer_addr()?); - let mut read_buf = [0; 4000]; - let mut read_cursor = 0; +type Handler = Box Result<()> + Send + Sync>; - if let Some(websocket_context) = read_header(&mut stream, &mut read_buf, &mut read_cursor)? { +static ROUTER: Lazy> = Lazy::new(|| { + let mut router = Router::new(); + router.add("/chat", Box::new(handle_chat) as Handler); + router.add("/", Box::new(handle_root) as Handler); + router +}); + +fn handle_chat(stream: &mut TcpStream, req: &Request) -> Result<()> { + println!("Received chat request: {:?}", req.path); + + if let Some(websocket_context) = + ws::read_http_header(req.headers.iter().map(|f| (f.name, f.value)))? + { // this is a websocket upgrade HTTP request + let mut read_buf = [0; 4000]; + let mut read_cursor = 0; let mut write_buf = [0; 4000]; let mut frame_buf = [0; 4000]; let mut websocket = WebSocketServer::new_server(); @@ -95,85 +115,65 @@ fn handle_client(mut stream: TcpStream) -> Result<()> { ); // complete the opening handshake with the client - framer.accept(&mut stream, &websocket_context)?; + framer.accept(stream, &websocket_context)?; println!("Websocket connection opened"); // read websocket frames - while let ReadResult::Text(text) = framer.read(&mut stream, &mut frame_buf)? { + while let ReadResult::Text(text) = framer.read(stream, &mut frame_buf)? { println!("Received: {}", text); // send the text back to the client framer.write( - &mut stream, + stream, WebSocketSendMessageType::Text, true, - text.as_bytes(), + format!("hello {}", text).as_bytes(), )? } println!("Closing websocket connection"); - Ok(()) - } else { - Ok(()) } + + Ok(()) } -fn read_header( - stream: &mut TcpStream, - read_buf: &mut [u8], - read_cursor: &mut usize, -) -> Result> { - loop { - let mut headers = [httparse::EMPTY_HEADER; 16]; - let mut request = httparse::Request::new(&mut headers); - - let received_size = stream.read(&mut read_buf[*read_cursor..])?; - - match request - .parse(&read_buf[..*read_cursor + received_size]) - .unwrap() - { - httparse::Status::Complete(len) => { - // if we read exactly the right amount of bytes for the HTTP header then read_cursor would be 0 - *read_cursor += received_size - len; - let headers = request.headers.iter().map(|f| (f.name, f.value)); - match ws::read_http_header(headers)? { - Some(websocket_context) => match request.path { - Some("/chat") => { - return Ok(Some(websocket_context)); - } - _ => return_404_not_found(stream, request.path)?, - }, - None => { - handle_non_websocket_http_request(stream, request.path)?; - } - } - return Ok(None); - } - // keep reading while the HTTP header is incomplete - httparse::Status::Partial => *read_cursor += received_size, - } - } +fn handle_root(stream: &mut TcpStream, _req: &Request) -> Result<()> { + stream.write_all(&ROOT_HTML.as_bytes())?; + Ok(()) } -fn handle_non_websocket_http_request(stream: &mut TcpStream, path: Option<&str>) -> Result<()> { - println!("Received file request: {:?}", path); +fn handle_client(mut stream: TcpStream) -> Result<()> { + println!("Client connected {}", stream.peer_addr()?); + let mut read_buf = [0; 4000]; + let mut read_cursor = 0; - match path { - Some("/") => stream.write_all(&ROOT_HTML.as_bytes())?, - unknown_path => { - return_404_not_found(stream, unknown_path)?; + let mut headers = vec![httparse::EMPTY_HEADER; 8]; + let received_size = stream.read(&mut read_buf[read_cursor..])?; + let request = loop { + let mut request = Request::new(&mut headers); + match request.parse(&read_buf[..read_cursor + received_size]) { + Ok(httparse::Status::Partial) => read_cursor += received_size, + Ok(httparse::Status::Complete(_)) => break request, + Err(httparse::Error::TooManyHeaders) => { + headers.resize(headers.len() * 2, httparse::EMPTY_HEADER) + } + Err(e) => return Err(e.into()), } }; - Ok(()) -} - -fn return_404_not_found(stream: &mut TcpStream, unknown_path: Option<&str>) -> Result<()> { - println!("Unknown path: {:?}", unknown_path); - let html = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; - stream.write_all(&html.as_bytes())?; - Ok(()) + match ROUTER.recognize(request.path.unwrap_or("/")) { + Ok(handler) => handler.handler()(&mut stream, &request), + Err(e) => { + println!("Unknown path: {:?}", request.path); + let html = format!( + "HTTP/1.1 404 Not Found\r\nContent-Length: {len}\r\nConnection: close\r\n\r\n{msg}", + len = e.len(), + msg = e + ); + stream.write_all(&html.as_bytes())?; + Err(WebServerError::Custom(e)) + } + } } const ROOT_HTML : &str = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=UTF-8\r\nContent-Length: 2590\r\nConnection: close\r\n\r\n diff --git a/examples/server_async/compat.rs b/examples/server_async/compat.rs new file mode 100644 index 0000000..4fdf12d --- /dev/null +++ b/examples/server_async/compat.rs @@ -0,0 +1,110 @@ +#![allow(dead_code)] + +// This is an example implementation of compatibility extension, gladly stolen from futures_lite +// As this is far from being any useful, please do extend this on your own +pub trait CompatExt { + fn compat(self) -> Compat + where + Self: Sized; + fn compat_ref(&self) -> Compat<&Self>; + fn compat_mut(&mut self) -> Compat<&mut Self>; +} + +impl CompatExt for T { + fn compat(self) -> Compat + where + Self: Sized, + { + Compat(self) + } + + fn compat_ref(&self) -> Compat<&Self> { + Compat(self) + } + + fn compat_mut(&mut self) -> Compat<&mut Self> { + Compat(self) + } +} + +pub struct Compat(T); + +impl Compat { + pub fn get_ref(&self) -> &T { + &self.0 + } + + pub fn get_mut(&mut self) -> &mut T { + &mut self.0 + } + + pub fn into_inner(self) -> T { + self.0 + } +} + +#[cfg(feature = "example-tokio")] +pub mod tokio_compat { + use super::Compat; + use async_trait::async_trait; + use embedded_websocket::compat::{AsyncRead, AsyncWrite}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[async_trait] + impl AsyncRead for Compat { + async fn read(&mut self, buf: &mut [u8]) -> Result { + AsyncReadExt::read(self.get_mut(), buf).await + } + } + + #[async_trait] + impl AsyncWrite for Compat { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + AsyncWriteExt::write_all(self.get_mut(), buf).await + } + } +} + +#[cfg(feature = "example-smol")] +pub mod smol_compat { + use super::Compat; + use async_trait::async_trait; + use embedded_websocket::compat::{AsyncRead, AsyncWrite}; + use smol::io::{AsyncReadExt, AsyncWriteExt}; + + #[async_trait] + impl AsyncRead for Compat { + async fn read(&mut self, buf: &mut [u8]) -> Result { + AsyncReadExt::read(self.get_mut(), buf).await + } + } + + #[async_trait] + impl AsyncWrite for Compat { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + AsyncWriteExt::write_all(self.get_mut(), buf).await + } + } +} + +#[cfg(feature = "example-async-std")] +pub mod async_std_compat { + use super::Compat; + use async_std::io::{ReadExt, WriteExt}; + use async_trait::async_trait; + use embedded_websocket::compat::{AsyncRead, AsyncWrite}; + + #[async_trait] + impl AsyncRead for Compat { + async fn read(&mut self, buf: &mut [u8]) -> Result { + ReadExt::read(self.get_mut(), buf).await + } + } + + #[async_trait] + impl AsyncWrite for Compat { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + WriteExt::write_all(self.get_mut(), buf).await + } + } +} diff --git a/examples/server_async/main.rs b/examples/server_async/main.rs new file mode 100644 index 0000000..2febd7a --- /dev/null +++ b/examples/server_async/main.rs @@ -0,0 +1,309 @@ +// The MIT License (MIT) +// Copyright (c) 2019 David Haig + +// Demo websocket server that listens on localhost port 1337. +// If accessed from a browser it will return a web page that will automatically attempt to +// open a websocket connection to itself. Alternatively, the main example can be used to +// open a websocket connection directly. The server will echo all Text and Ping messages back to +// the client as well as responding to any opening and closing handshakes. +// Note that we are using the standard library in the demo but the websocket library remains no_std + +mod compat; + +use crate::compat::CompatExt; +use async_trait::async_trait; +use embedded_websocket as ws; +use httparse::Request; +use once_cell::sync::Lazy; +use route_recognizer::Router; +use std::str::Utf8Error; +use ws::{ + framer::{Framer, FramerError, ReadResult}, + WebSocketSendMessageType, WebSocketServer, +}; + +type Result = std::result::Result; + +cfg_if::cfg_if! { + if #[cfg(feature = "example-tokio")] { + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + }; + use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; + } else if #[cfg(feature = "example-smol")] { + use smol::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + stream::StreamExt, + }; + } else if #[cfg(feature = "example-async-std")] { + use async_std::{ + io::{ReadExt as AsyncReadExt, WriteExt as AsyncWriteExt}, + net::{TcpListener, TcpStream}, + stream::StreamExt, + }; + } +} + +#[derive(Debug)] +pub enum WebServerError { + Io(std::io::Error), + Framer(FramerError), + WebSocket(ws::Error), + HttpError(httparse::Error), + Utf8Error, + Custom(String), +} + +impl From for WebServerError { + fn from(err: std::io::Error) -> WebServerError { + WebServerError::Io(err) + } +} + +impl From> for WebServerError { + fn from(err: FramerError) -> WebServerError { + WebServerError::Framer(err) + } +} + +impl From for WebServerError { + fn from(err: ws::Error) -> WebServerError { + WebServerError::WebSocket(err) + } +} + +impl From for WebServerError { + fn from(_: Utf8Error) -> WebServerError { + WebServerError::Utf8Error + } +} + +impl From for WebServerError { + fn from(err: httparse::Error) -> WebServerError { + WebServerError::HttpError(err) + } +} + +#[cfg_attr(feature = "example-async-std", async_std::main)] +#[cfg_attr(feature = "example-tokio", tokio::main)] +#[cfg_attr(feature = "example-smol", smol_potat::main)] +async fn main() -> std::io::Result<()> { + let addr = "127.0.0.1:1337"; + let listener = TcpListener::bind(addr).await?; + println!("Listening on: {}", addr); + + let mut incoming = { + cfg_if::cfg_if! { + if #[cfg(feature = "example-tokio")] { + TcpListenerStream::new(listener) + } else { + listener.incoming() + } + } + }; + + while let Some(stream) = incoming.next().await { + match stream { + Ok(stream) => { + let fut = async { + match handle_client(stream).await { + Ok(()) => println!("Connection closed"), + Err(e) => println!("Error: {:?}", e), + } + }; + + #[cfg(feature = "example-async-std")] + async_std::task::spawn(fut); + + #[cfg(feature = "example-smol")] + smol::spawn(fut).detach(); + + #[cfg(feature = "example-tokio")] + tokio::spawn(fut); + } + Err(e) => println!("Failed to establish a connection: {}", e), + } + } + + Ok(()) +} + +type Handler = Box; + +static ROUTER: Lazy> = Lazy::new(|| { + let mut router = Router::new(); + router.add("/chat", Box::new(Chat) as Handler); + router.add("/", Box::new(Root) as Handler); + router +}); + +#[async_trait] +trait SimpleHandler { + async fn handle(&self, stream: &mut TcpStream, req: &Request<'_, '_>) -> Result<()>; +} + +struct Chat; + +#[async_trait] +impl SimpleHandler for Chat { + async fn handle(&self, stream: &mut TcpStream, req: &Request<'_, '_>) -> Result<()> { + println!("Received chat request: {:?}", req.path); + + if let Some(websocket_context) = + ws::read_http_header(req.headers.iter().map(|f| (f.name, f.value)))? + { + // this is a websocket upgrade HTTP request + let mut read_buf = [0; 4000]; + let mut read_cursor = 0; + let mut write_buf = [0; 4000]; + let mut frame_buf = [0; 4000]; + let mut websocket = WebSocketServer::new_server(); + let mut framer = Framer::new( + &mut read_buf, + &mut read_cursor, + &mut write_buf, + &mut websocket, + ); + + // complete the opening handshake with the client + let mut stream = stream.compat_mut(); + framer.accept_async(&mut stream, &websocket_context).await?; + println!("Websocket connection opened"); + + // read websocket frames + while let ReadResult::Text(text) = + framer.read_async(&mut stream, &mut frame_buf).await? + { + println!("Received: {}", text); + + // send the text back to the client + framer + .write_async( + &mut stream, + WebSocketSendMessageType::Text, + true, + format!("hello {}", text).as_bytes(), + ) + .await? + } + + println!("Closing websocket connection"); + } + + Ok(()) + } +} + +struct Root; + +#[async_trait] +impl SimpleHandler for Root { + async fn handle(&self, stream: &mut TcpStream, _req: &Request<'_, '_>) -> Result<()> { + stream.write_all(&ROOT_HTML.as_bytes()).await?; + Ok(()) + } +} + +async fn handle_client(mut stream: TcpStream) -> Result<()> { + println!("Client connected {}", stream.peer_addr()?); + let mut read_buf = [0; 4000]; + let mut read_cursor = 0; + + let mut headers = vec![httparse::EMPTY_HEADER; 8]; + let received_size = stream.read(&mut read_buf[read_cursor..]).await?; + let request = loop { + let mut request = Request::new(&mut headers); + match request.parse(&read_buf[..read_cursor + received_size]) { + Ok(httparse::Status::Partial) => read_cursor += received_size, + Ok(httparse::Status::Complete(_)) => break request, + Err(httparse::Error::TooManyHeaders) => { + headers.resize(headers.len() * 2, httparse::EMPTY_HEADER) + } + Err(e) => return Err(e.into()), + } + }; + + match ROUTER.recognize(request.path.unwrap_or("/")) { + Ok(handler) => handler.handler().handle(&mut stream, &request).await, + Err(e) => { + println!("Unknown path: {:?}", request.path); + let html = format!( + "HTTP/1.1 404 Not Found\r\nContent-Length: {len}\r\nConnection: close\r\n\r\n{msg}", + len = e.len(), + msg = e + ); + stream.write_all(&html.as_bytes()).await?; + Err(WebServerError::Custom(e)) + } + } +} + +const ROOT_HTML : &str = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=UTF-8\r\nContent-Length: 2590\r\nConnection: close\r\n\r\n + + + + + + + + Web Socket Demo + + + +
    +
    + +
    + + + +"; diff --git a/src/compat.rs b/src/compat.rs new file mode 100644 index 0000000..9fbefa2 --- /dev/null +++ b/src/compat.rs @@ -0,0 +1,40 @@ +use core::result::Result; +// use core::pin::Pin; + +pub trait Read { + fn read(&mut self, buf: &mut [u8]) -> Result; +} + +pub trait Write { + fn write_all(&mut self, buf: &[u8]) -> Result<(), E>; +} + +cfg_if::cfg_if! { + if #[cfg(feature = "async")] { + #[async_trait::async_trait] + pub trait AsyncRead { + async fn read(&mut self, buf: &mut [u8]) -> Result; + } + + #[async_trait::async_trait] + pub trait AsyncWrite { + async fn write_all(&mut self, buf: &[u8]) -> Result<(), E>; + } + } +} + +cfg_if::cfg_if! { + if #[cfg(feature = "std")] { + impl Read for std::net::TcpStream { + fn read(&mut self, buf: &mut [u8]) -> Result { + std::io::Read::read(self, buf) + } + } + + impl Write for std::net::TcpStream { + fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { + std::io::Write::write_all(self, buf) + } + } + } +} diff --git a/src/framer.rs b/src/framer.rs index 869cabf..c03977f 100644 --- a/src/framer.rs +++ b/src/framer.rs @@ -2,35 +2,17 @@ // This is the most common use case when working with websockets and is recommended due to the hand shaky nature of // the protocol as well as the fact that an input buffer can contain multiple websocket frames or maybe only a fragment of one. // This module allows you to work with discrete websocket frames rather than the multiple fragments you read off a stream. -// NOTE: if you are using the standard library then you can use the built in Read and Write traits from std otherwise -// you have to implement the Read and Write traits specified below +// NOTE: if you are using the standard library then you can use the built in compat::Read and compat::Write traits from std otherwise +// you have to implement the compat::Read and compat::Write traits specified below use crate::{ - WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions, + compat, WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions, WebSocketReceiveMessageType, WebSocketSendMessageType, WebSocketState, WebSocketSubProtocol, WebSocketType, }; use core::{cmp::min, str::Utf8Error}; use rand_core::RngCore; -// automagically implement the Stream trait for TcpStream if we are using the standard library -// if you were using no_std you would have to implement your own stream -#[cfg(feature = "std")] -impl Stream for std::net::TcpStream { - fn read(&mut self, buf: &mut [u8]) -> Result { - std::io::Read::read(self, buf) - } - - fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> { - std::io::Write::write_all(self, buf) - } -} - -pub trait Stream { - fn read(&mut self, buf: &mut [u8]) -> Result; - fn write_all(&mut self, buf: &[u8]) -> Result<(), E>; -} - pub enum ReadResult<'a> { Binary(&'a [u8]), Text(&'a str), @@ -66,7 +48,7 @@ where { pub fn connect( &mut self, - stream: &mut impl Stream, + stream: &mut (impl compat::Read + compat::Write), websocket_options: &WebSocketOptions, ) -> Result, FramerError> { let (len, web_socket_key) = self @@ -113,23 +95,29 @@ where { pub fn accept( &mut self, - stream: &mut impl Stream, + stream: &mut impl compat::Write, websocket_context: &WebSocketContext, ) -> Result<(), FramerError> { - let len = self - .websocket - .server_accept( - &websocket_context.sec_websocket_key, - None, - &mut self.write_buf, - ) - .map_err(FramerError::WebSocket)?; + let len = self.accept_len(&websocket_context)?; stream .write_all(&self.write_buf[..len]) .map_err(FramerError::Io)?; Ok(()) } + + fn accept_len( + &mut self, + websocket_context: &&WebSocketContext, + ) -> Result> { + self.websocket + .server_accept( + &websocket_context.sec_websocket_key, + None, + &mut self.write_buf, + ) + .map_err(FramerError::WebSocket) + } } impl<'a, TRng, TWebSocketType> Framer<'a, TRng, TWebSocketType> @@ -159,34 +147,49 @@ where self.websocket.state } + fn close_len( + &mut self, + close_status: WebSocketCloseStatusCode, + status_description: Option<&str>, + ) -> Result> { + self.websocket + .close(close_status, status_description, self.write_buf) + .map_err(FramerError::WebSocket) + } + // calling close on a websocket that has already been closed by the other party has no effect pub fn close( &mut self, - stream: &mut impl Stream, + stream: &mut impl compat::Write, close_status: WebSocketCloseStatusCode, status_description: Option<&str>, ) -> Result<(), FramerError> { - let len = self - .websocket - .close(close_status, status_description, self.write_buf) - .map_err(FramerError::WebSocket)?; + let len = self.close_len(close_status, status_description)?; stream .write_all(&self.write_buf[..len]) .map_err(FramerError::Io)?; Ok(()) } + fn write_len( + &mut self, + message_type: WebSocketSendMessageType, + end_of_message: bool, + frame_buf: &[u8], + ) -> Result> { + self.websocket + .write(message_type, end_of_message, frame_buf, self.write_buf) + .map_err(FramerError::WebSocket) + } + pub fn write( &mut self, - stream: &mut impl Stream, + stream: &mut impl compat::Write, message_type: WebSocketSendMessageType, end_of_message: bool, frame_buf: &[u8], ) -> Result<(), FramerError> { - let len = self - .websocket - .write(message_type, end_of_message, frame_buf, self.write_buf) - .map_err(FramerError::WebSocket)?; + let len = self.write_len(message_type, end_of_message, frame_buf)?; stream .write_all(&self.write_buf[..len]) .map_err(FramerError::Io)?; @@ -198,7 +201,7 @@ where // It will wait until the last fragmented frame has arrived. pub fn read<'b, E>( &mut self, - stream: &mut impl Stream, + stream: &mut (impl compat::Read + compat::Write), frame_buf: &'b mut [u8], ) -> Result, FramerError> { loop { @@ -278,19 +281,232 @@ where fn send_back( &mut self, - stream: &mut impl Stream, + stream: &mut impl compat::Write, frame_buf: &'_ mut [u8], len_to: usize, send_message_type: WebSocketSendMessageType, ) -> Result<(), FramerError> { + let len = self.send_back_data(&frame_buf, len_to, send_message_type)?; + stream + .write_all(&self.write_buf[..len]) + .map_err(FramerError::Io)?; + Ok(()) + } + + fn send_back_data( + &mut self, + frame_buf: &&mut [u8], + len_to: usize, + send_message_type: WebSocketSendMessageType, + ) -> Result> { let payload_len = min(self.write_buf.len(), len_to); let from = &frame_buf[self.frame_cursor..self.frame_cursor + payload_len]; - let len = self - .websocket + self.websocket .write(send_message_type, true, from, &mut self.write_buf) + .map_err(FramerError::WebSocket) + } +} + +#[cfg(feature = "async")] +impl<'a, TRng> Framer<'a, TRng, crate::Client> +where + TRng: RngCore, +{ + pub async fn connect_async + compat::AsyncWrite + Unpin, E>( + &mut self, + stream: &mut S, + websocket_options: &'a WebSocketOptions<'a>, + ) -> Result, FramerError> { + let (len, web_socket_key) = self + .websocket + .client_connect(websocket_options, &mut self.write_buf) .map_err(FramerError::WebSocket)?; stream .write_all(&self.write_buf[..len]) + .await + .map_err(FramerError::Io)?; + *self.read_cursor = 0; + + loop { + // read the response from the server and check it to complete the opening handshake + let received_size = stream + .read(&mut self.read_buf[*self.read_cursor..]) + .await + .map_err(FramerError::Io)?; + + match self.websocket.client_accept( + &web_socket_key, + &self.read_buf[..*self.read_cursor + received_size], + ) { + Ok((len, sub_protocol)) => { + // "consume" the HTTP header that we have read from the stream + // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else + *self.read_cursor += received_size - len; + return Ok(sub_protocol); + } + Err(crate::Error::HttpHeaderIncomplete) => { + *self.read_cursor += received_size; + // continue reading HTTP header in loop + } + Err(e) => { + *self.read_cursor += received_size; + return Err(FramerError::WebSocket(e)); + } + } + } + } +} + +#[cfg(feature = "async")] +impl<'a, TRng> Framer<'a, TRng, crate::Server> +where + TRng: RngCore, +{ + pub async fn accept_async + Unpin, E>( + &mut self, + stream: &mut W, + websocket_context: &WebSocketContext, + ) -> Result<(), FramerError> { + let len = self.accept_len(&websocket_context)?; + + stream + .write_all(&self.write_buf[..len]) + .await + .map_err(FramerError::Io)?; + Ok(()) + } +} + +#[cfg(feature = "async")] +impl<'a, TRng, TWebSocketType> Framer<'a, TRng, TWebSocketType> +where + TRng: RngCore, + TWebSocketType: WebSocketType, +{ + // calling close on a websocket that has already been closed by the other party has no effect + pub async fn close_async + Unpin, E>( + &mut self, + stream: &mut W, + close_status: WebSocketCloseStatusCode, + status_description: Option<&str>, + ) -> Result<(), FramerError> { + let len = self.close_len(close_status, status_description)?; + stream + .write_all(&self.write_buf[..len]) + .await + .map_err(FramerError::Io)?; + Ok(()) + } + + pub async fn write_async + Unpin, E>( + &mut self, + stream: &mut W, + message_type: WebSocketSendMessageType, + end_of_message: bool, + frame_buf: &[u8], + ) -> Result<(), FramerError> { + let len = self.write_len(message_type, end_of_message, frame_buf)?; + + stream + .write_all(&self.write_buf[..len]) + .await + .map_err(FramerError::Io)?; + Ok(()) + } + + pub async fn read_async<'b, S: compat::AsyncWrite + compat::AsyncRead + Unpin, E>( + &mut self, + stream: &mut S, + frame_buf: &'b mut [u8], + ) -> Result, FramerError> { + loop { + if *self.read_cursor == 0 || *self.read_cursor == self.read_len { + self.read_len = stream.read(self.read_buf).await.map_err(FramerError::Io)?; + *self.read_cursor = 0; + } + + if self.read_len == 0 { + return Ok(ReadResult::Closed); + } + + loop { + if *self.read_cursor == self.read_len { + break; + } + + if self.frame_cursor == frame_buf.len() { + return Err(FramerError::FrameTooLarge(frame_buf.len())); + } + + let ws_result = self + .websocket + .read( + &self.read_buf[*self.read_cursor..self.read_len], + &mut frame_buf[self.frame_cursor..], + ) + .map_err(FramerError::WebSocket)?; + + *self.read_cursor += ws_result.len_from; + + match ws_result.message_type { + WebSocketReceiveMessageType::Binary => { + self.frame_cursor += ws_result.len_to; + if ws_result.end_of_message { + let frame = &frame_buf[..self.frame_cursor]; + self.frame_cursor = 0; + return Ok(ReadResult::Binary(frame)); + } + } + WebSocketReceiveMessageType::Text => { + self.frame_cursor += ws_result.len_to; + if ws_result.end_of_message { + let frame = &frame_buf[..self.frame_cursor]; + self.frame_cursor = 0; + let text = core::str::from_utf8(frame).map_err(FramerError::Utf8)?; + return Ok(ReadResult::Text(text)); + } + } + WebSocketReceiveMessageType::CloseMustReply => { + self.send_back_async( + stream, + frame_buf, + ws_result.len_to, + WebSocketSendMessageType::CloseReply, + ) + .await?; + return Ok(ReadResult::Closed); + } + WebSocketReceiveMessageType::CloseCompleted => return Ok(ReadResult::Closed), + WebSocketReceiveMessageType::Ping => { + self.send_back_async( + stream, + frame_buf, + ws_result.len_to, + WebSocketSendMessageType::Pong, + ) + .await?; + } + WebSocketReceiveMessageType::Pong => { + let bytes = + &frame_buf[self.frame_cursor..self.frame_cursor + ws_result.len_to]; + return Ok(ReadResult::Pong(bytes)); + } + } + } + } + } + + async fn send_back_async + Unpin, E>( + &mut self, + stream: &mut W, + frame_buf: &'_ mut [u8], + len_to: usize, + send_message_type: WebSocketSendMessageType, + ) -> Result<(), FramerError> { + let len = self.send_back_data(&frame_buf, len_to, send_message_type)?; + stream + .write_all(&self.write_buf[..len]) + .await .map_err(FramerError::Io)?; Ok(()) } diff --git a/src/http.rs b/src/http.rs index 7d60836..c7983c7 100644 --- a/src/http.rs +++ b/src/http.rs @@ -135,7 +135,16 @@ pub fn build_connect_handshake_request( let mut key: [u8; 16] = [0; 16]; rng.fill_bytes(&mut key); - base64::encode_config_slice(&key, base64::STANDARD, &mut key_as_base64); + + cfg_if::cfg_if! { + if #[cfg(feature = "base64-simd")] { + use base64_simd::{Base64, OutBuf}; + Base64::STANDARD.encode(&key, OutBuf::from_slice_mut(&mut key_as_base64))?; + } else { + base64::encode_config_slice(&key, base64::STANDARD, &mut key_as_base64); + } + } + let sec_websocket_key: String<24> = String::from(str::from_utf8(&key_as_base64)?); http_request.push_str("GET ")?; @@ -212,6 +221,15 @@ pub fn build_accept_string(sec_websocket_key: &WebSocketKey, output: &mut [u8]) let mut sha1 = Sha1::new(); sha1.update(&accept_string); let input = sha1.finalize(); - base64::encode_config_slice(&input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes + + cfg_if::cfg_if! { + if #[cfg(feature = "base64-simd")] { + use base64_simd::{Base64, OutBuf}; + Base64::STANDARD.encode(&input, OutBuf::from_slice_mut(output))?; + } else { + base64::encode_config_slice(&input, base64::STANDARD, output); // no need for slices since the output WILL be 28 bytes + } + } + Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index c3f42d8..35c4899 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ pub use self::random::EmptyRng; // support for working with discrete websocket frames when using IO streams // start here!! +pub mod compat; pub mod framer; const MASK_KEY_LEN: usize = 4; @@ -211,6 +212,8 @@ pub enum Error { ConvertInfallible, RandCore, UnexpectedContinuationFrame, + #[cfg(feature = "base64-simd")] + Base64Error, } impl From for Error { @@ -237,6 +240,13 @@ impl From<()> for Error { } } +#[cfg(feature = "base64-simd")] +impl From for Error { + fn from(_: base64_simd::Error) -> Error { + Error::Base64Error + } +} + #[derive(Copy, Clone, Debug, PartialEq)] enum WebSocketOpCode { ContinuationFrame = 0,