From 1af1e1c91c293a569eb0569c406a88dc07895c61 Mon Sep 17 00:00:00 2001 From: kanarus Date: Thu, 25 Apr 2024 18:19:53 +0900 Subject: [PATCH] Improve request parsing (#131) * @2024-04-24 17:27+9:00 * Update session to handle errors in request parsing * not DEBUG * improve session * Clear warnings --- examples/hello/src/main.rs | 2 +- ohkami/src/request/_test_headers.rs | 1 + ohkami/src/request/headers.rs | 5 ++ ohkami/src/request/mod.rs | 107 ++++++++++++++-------------- ohkami/src/request/path.rs | 8 +-- ohkami/src/session/mod.rs | 32 +++++---- ohkami/src/testing/mod.rs | 8 ++- 7 files changed, 89 insertions(+), 74 deletions(-) diff --git a/examples/hello/src/main.rs b/examples/hello/src/main.rs index 330301fa..9ae6174b 100644 --- a/examples/hello/src/main.rs +++ b/examples/hello/src/main.rs @@ -8,7 +8,7 @@ mod health_handler { mod hello_handler { - use ohkami::{Response, Status}; + use ohkami::Response; use ohkami::typed::{Payload, Query}; use ohkami::builtin::payload::JSON; diff --git a/ohkami/src/request/_test_headers.rs b/ohkami/src/request/_test_headers.rs index 9c9833ea..4e954c72 100644 --- a/ohkami/src/request/_test_headers.rs +++ b/ohkami/src/request/_test_headers.rs @@ -1,3 +1,4 @@ +#![cfg(any(feature="testing", feature="DEBUG"))] #![cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))] use std::borrow::Cow; diff --git a/ohkami/src/request/headers.rs b/ohkami/src/request/headers.rs index 2de0fe51..31d8efad 100644 --- a/ohkami/src/request/headers.rs +++ b/ohkami/src/request/headers.rs @@ -334,6 +334,11 @@ impl Headers { #[cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))] impl Headers { + #[allow(unused)] + #[inline] pub(crate) fn get_raw(&self, name: Header) -> Option<&CowSlice> { + unsafe {self.standard.get_unchecked(name as usize)}.as_ref() + } + #[inline] pub(crate) fn insert_custom(&mut self, name: CowSlice, value: CowSlice) { match &mut self.custom { Some(c) => {c.insert(name, value);} diff --git a/ohkami/src/request/mod.rs b/ohkami/src/request/mod.rs index c2398631..95fd2373 100644 --- a/ohkami/src/request/mod.rs +++ b/ohkami/src/request/mod.rs @@ -9,6 +9,8 @@ pub(crate) use queries::QueryParams; mod headers; pub use headers::Headers as RequestHeaders; +#[allow(unused)] +pub use headers::Header as RequestHeader; mod memory; pub(crate) use memory::Store; @@ -35,10 +37,6 @@ use { std::pin::Pin, std::borrow::Cow, }; -#[cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))] -pub use { - headers::Header as RequestHeader, -}; #[cfg(feature="websocket")] use crate::websocket::UpgradeID; @@ -164,69 +162,74 @@ impl Request { pub(crate) async fn read( mut self: Pin<&mut Self>, stream: &mut (impl AsyncReader + Unpin), - ) -> Option<()> { + ) -> Option> { + use crate::Response; + if stream.read(&mut *self.__buf__).await.ok()? == 0 {return None}; - let mut r = Reader::new(&*self.__buf__); + let mut r = Reader::new(unsafe { + // pass detouched bytes + // to resolve immutable/mutable borrowing + // + // SAFETY: `self.__buf__` itself is immutable + Slice::from_bytes(&*self.__buf__).as_bytes() + }); - let method = Method::from_bytes(r.read_while(|b| b != &b' '))?; - r.consume(" ").unwrap(); + self.method = Method::from_bytes(r.read_while(|b| b != &b' '))?; + if r.consume(" ").is_none() { + return Some(Err((|| Response::BadRequest())())) + } - let path = unsafe {// SAFETY: Just calling for request bytes - Path::from_request_bytes(r.read_while(|b| b != &b'?' && b != &b' ')) + self.path = match Path::from_request_bytes(r.read_while(|b| b != &b'?' && b != &b' ')) { + Ok(path) => path, + Err(res) => return Some(Err(res)) }; - let query = (r.consume_oneof([" ", "?"]).unwrap() == 1) - .then(|| { - let q = QueryParams::new(r.read_while(|b| b != &b' ')); - #[cfg(debug_assertions)] { - r.consume(" ").unwrap(); - } #[cfg(not(debug_assertions))] { - r.advance_by(1) - } - q - }); + if r.consume_oneof([" ", "?"]).unwrap() == 1 { + self.query = QueryParams::new(r.read_while(|b| b != &b' ')); + r.advance_by(1); + } - r.consume("HTTP/1.1\r\n").expect("Ohkami can only handle HTTP/1.1"); + if r.consume("HTTP/1.1\r\n").is_none() { + return Some(Err((|| Response::HTTPVersionNotSupported())())) + } - let mut headers = RequestHeaders::init(); while r.consume("\r\n").is_none() { let key_bytes = r.read_while(|b| b != &b':'); - r.consume(": ").unwrap(); + if r.consume(": ").is_none() { + return Some(Err((|| Response::BadRequest())())) + } if let Some(key) = RequestHeader::from_bytes(key_bytes) { - headers.insert(key, CowSlice::Ref( + self.headers.insert(key, CowSlice::Ref( Slice::from_bytes(r.read_while(|b| b != &b'\r')) )); } else { - headers.insert_custom( + self.headers.insert_custom( CowSlice::Ref(Slice::from_bytes(key_bytes)), CowSlice::Ref(Slice::from_bytes(r.read_while(|b| b != &b'\r'))) ); } - r.consume("\r\n"); + if r.consume("\r\n").is_none() { + return Some(Err((|| Response::BadRequest())())) + } } - let content_length = headers.ContentLength() - .unwrap_or("") - .as_bytes().into_iter() - .fold(0, |len, b| 10*len + (*b - b'0') as usize); + let content_length = match self.headers.get_raw(RequestHeader::ContentLength) { + Some(v) => unsafe {v.as_bytes()}.into_iter().fold(0, |len, b| 10*len + (*b - b'0') as usize), + None => 0, + }; + if content_length > PAYLOAD_LIMIT { + return Some(Err((|| Response::PayloadTooLarge())())) + } - let payload = if content_length > 0 { - Some(Request::read_payload( + if content_length > 0 { + self.payload = Some(Request::read_payload( stream, r.remaining(), - content_length.min(PAYLOAD_LIMIT), - ).await) - } else {None}; + content_length, + ).await); + } - Some({ - self.method = method; - self.path = path; - if let Some(query) = query { - self.query = query - }; - self.headers = headers; - self.payload = payload; - }) + Some(Ok(())) } #[cfg(any(feature="rt_tokio", feature="rt_async-std"))] @@ -264,7 +267,7 @@ impl Request { #[cfg(feature="testing")] pub(crate) async fn read(mut self: Pin<&mut Self>, raw_bytes: &mut &[u8] - ) -> Option<()> { + ) -> Option> { let mut r = Reader::new(raw_bytes); self.method = Method::from_bytes(r.read_while(|b| b != &b' '))?; @@ -277,7 +280,7 @@ impl Request { }); // SAFETY: Just calling for request bytes and `self.__url__` is already initialized unsafe {let __url__ = self.__url__.assume_init_ref(); - let path = Path::from_request_bytes(__url__.path().as_bytes()); + let path = Path::from_request_bytes(__url__.path().as_bytes()).unwrap(); let query = __url__.query().map(|str| QueryParams::new(str.as_bytes())); self.path = path; if let Some(query) = query { @@ -304,16 +307,16 @@ impl Request { } self.payload = { - let content_length = self.headers.ContentLength() - .unwrap_or("") - .as_bytes().into_iter() - .fold(0, |len, b| 10*len + (*b - b'0') as usize); + let content_length = match self.headers.get_raw(RequestHeader::ContentLength) { + Some(v) => unsafe {v.as_bytes()}.into_iter().fold(0, |len, b| 10*len + (*b - b'0') as usize), + None => 0, + }; (content_length > 0).then_some(CowSlice::Own( r.remaining().into() )) }; - Some(()) + Some(Ok(())) } #[cfg(feature="rt_worker")] @@ -335,7 +338,7 @@ impl Request { // SAFETY: Just calling for request bytes and `self.__url__` is already initialized unsafe {let __url__ = self.__url__.assume_init_ref(); - let path = Path::from_request_bytes(__url__.path().as_bytes()); + let path = Path::from_request_bytes(__url__.path().as_bytes()).unwrap(); let query = __url__.query().map(|str| QueryParams::new(str.as_bytes())); self.path = path; if let Some(query) = query { diff --git a/ohkami/src/request/path.rs b/ohkami/src/request/path.rs index 1752163f..be11c304 100644 --- a/ohkami/src/request/path.rs +++ b/ohkami/src/request/path.rs @@ -18,7 +18,7 @@ impl Path { } } - #[inline] pub(crate) unsafe fn from_request_bytes(bytes: &[u8]) -> Self { + #[inline] pub(crate) fn from_request_bytes(bytes: &[u8]) -> Result { #[cfg(debug_assertions)] debug_assert! { bytes.starts_with(b"/") @@ -34,12 +34,12 @@ impl Path { returns `b"/"` if that bytes is `b"/"`. */ let mut len = bytes.len(); - if *bytes.get_unchecked(len-1) == b'/' {len -= 1}; + if *unsafe {bytes.get_unchecked(len-1)} == b'/' {len -= 1}; - Self { + Ok(Self { raw: Slice::new_unchecked(bytes.as_ptr(), len), params: List::new(), - } + }) } #[inline] pub(crate) fn push_param(&mut self, param: Slice) { diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 2ab71956..5846bd9a 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -9,13 +9,13 @@ use crate::{Request, Response}; pub(crate) struct Session { - router: Arc, - connection: TcpStream, + router: Arc, + connection: TcpStream, } impl Session { pub(crate) fn new( - router: Arc, - connection: TcpStream, + router: Arc, + connection: TcpStream, ) -> Self { Self { router, @@ -40,17 +40,21 @@ impl Session { loop { let mut req = Request::init(); let mut req = unsafe {Pin::new_unchecked(&mut req)}; - if req.as_mut().read(connection).await.is_none() {break} - - let close = req.headers.Connection().is_some_and(|c| c == "close"); - - let res = match catch_unwind(AssertUnwindSafe(|| self.router.handle(req.get_mut()))) { - Ok(future) => future.await, - Err(panic) => panicking(panic), + match req.as_mut().read(connection).await { + Some(Ok(())) => { + let close = req.headers.Connection() == Some("close"); + let res = match catch_unwind(AssertUnwindSafe(|| self.router.handle(req.get_mut()))) { + Ok(future) => future.await, + Err(panic) => panicking(panic), + }; + res.send(connection).await; + if close {break} + } + Some(Err(res)) => { + res.send(connection).await + } + None => break }; - res.send(connection).await; - - if close {break} } } } diff --git a/ohkami/src/testing/mod.rs b/ohkami/src/testing/mod.rs index a11bbe8d..d1a01ca8 100644 --- a/ohkami/src/testing/mod.rs +++ b/ohkami/src/testing/mod.rs @@ -57,9 +57,11 @@ impl TestingOhkami { let res = async move { let mut request = Request::init(); let mut request = unsafe {Pin::new_unchecked(&mut request)}; - request.as_mut().read(&mut &req.encode()[..]).await; - - let res = router.handle(&mut request).await; + + let res = match request.as_mut().read(&mut &req.encode()[..]).await.unwrap() { + Ok(()) => router.handle(&mut request).await, + Err(res) => res, + }; TestResponse::new(res) };