Skip to content

Commit

Permalink
Improve request parsing (#131)
Browse files Browse the repository at this point in the history
* @2024-04-24 17:27+9:00

* Update session to handle errors in request parsing

* not DEBUG

* improve session

* Clear warnings
  • Loading branch information
kana-rus committed Apr 25, 2024
1 parent 5e90774 commit 1af1e1c
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 74 deletions.
2 changes: 1 addition & 1 deletion examples/hello/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod health_handler {


mod hello_handler {
use ohkami::{Response, Status};
use ohkami::Response;
use ohkami::typed::{Payload, Query};
use ohkami::builtin::payload::JSON;

Expand Down
1 change: 1 addition & 0 deletions ohkami/src/request/_test_headers.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![cfg(any(feature="testing", feature="DEBUG"))]
#![cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))]

use std::borrow::Cow;
Expand Down
5 changes: 5 additions & 0 deletions ohkami/src/request/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ impl Headers {

#[cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))]
impl Headers {
#[allow(unused)]
#[inline] pub(crate) fn get_raw(&self, name: Header) -> Option<&CowSlice> {
unsafe {self.standard.get_unchecked(name as usize)}.as_ref()
}

#[inline] pub(crate) fn insert_custom(&mut self, name: CowSlice, value: CowSlice) {
match &mut self.custom {
Some(c) => {c.insert(name, value);}
Expand Down
107 changes: 55 additions & 52 deletions ohkami/src/request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub(crate) use queries::QueryParams;

mod headers;
pub use headers::Headers as RequestHeaders;
#[allow(unused)]
pub use headers::Header as RequestHeader;

mod memory;
pub(crate) use memory::Store;
Expand All @@ -35,10 +37,6 @@ use {
std::pin::Pin,
std::borrow::Cow,
};
#[cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))]
pub use {
headers::Header as RequestHeader,
};

#[cfg(feature="websocket")]
use crate::websocket::UpgradeID;
Expand Down Expand Up @@ -164,69 +162,74 @@ impl Request {
pub(crate) async fn read(
mut self: Pin<&mut Self>,
stream: &mut (impl AsyncReader + Unpin),
) -> Option<()> {
) -> Option<Result<(), crate::Response>> {
use crate::Response;

if stream.read(&mut *self.__buf__).await.ok()? == 0 {return None};
let mut r = Reader::new(&*self.__buf__);
let mut r = Reader::new(unsafe {
// pass detouched bytes
// to resolve immutable/mutable borrowing
//
// SAFETY: `self.__buf__` itself is immutable
Slice::from_bytes(&*self.__buf__).as_bytes()
});

let method = Method::from_bytes(r.read_while(|b| b != &b' '))?;
r.consume(" ").unwrap();
self.method = Method::from_bytes(r.read_while(|b| b != &b' '))?;
if r.consume(" ").is_none() {
return Some(Err((|| Response::BadRequest())()))
}

let path = unsafe {// SAFETY: Just calling for request bytes
Path::from_request_bytes(r.read_while(|b| b != &b'?' && b != &b' '))
self.path = match Path::from_request_bytes(r.read_while(|b| b != &b'?' && b != &b' ')) {
Ok(path) => path,
Err(res) => return Some(Err(res))
};

let query = (r.consume_oneof([" ", "?"]).unwrap() == 1)
.then(|| {
let q = QueryParams::new(r.read_while(|b| b != &b' '));
#[cfg(debug_assertions)] {
r.consume(" ").unwrap();
} #[cfg(not(debug_assertions))] {
r.advance_by(1)
}
q
});
if r.consume_oneof([" ", "?"]).unwrap() == 1 {
self.query = QueryParams::new(r.read_while(|b| b != &b' '));
r.advance_by(1);
}

r.consume("HTTP/1.1\r\n").expect("Ohkami can only handle HTTP/1.1");
if r.consume("HTTP/1.1\r\n").is_none() {
return Some(Err((|| Response::HTTPVersionNotSupported())()))
}

let mut headers = RequestHeaders::init();
while r.consume("\r\n").is_none() {
let key_bytes = r.read_while(|b| b != &b':');
r.consume(": ").unwrap();
if r.consume(": ").is_none() {
return Some(Err((|| Response::BadRequest())()))
}
if let Some(key) = RequestHeader::from_bytes(key_bytes) {
headers.insert(key, CowSlice::Ref(
self.headers.insert(key, CowSlice::Ref(
Slice::from_bytes(r.read_while(|b| b != &b'\r'))
));
} else {
headers.insert_custom(
self.headers.insert_custom(
CowSlice::Ref(Slice::from_bytes(key_bytes)),
CowSlice::Ref(Slice::from_bytes(r.read_while(|b| b != &b'\r')))
);
}
r.consume("\r\n");
if r.consume("\r\n").is_none() {
return Some(Err((|| Response::BadRequest())()))
}
}

let content_length = headers.ContentLength()
.unwrap_or("")
.as_bytes().into_iter()
.fold(0, |len, b| 10*len + (*b - b'0') as usize);
let content_length = match self.headers.get_raw(RequestHeader::ContentLength) {
Some(v) => unsafe {v.as_bytes()}.into_iter().fold(0, |len, b| 10*len + (*b - b'0') as usize),
None => 0,
};
if content_length > PAYLOAD_LIMIT {
return Some(Err((|| Response::PayloadTooLarge())()))
}

let payload = if content_length > 0 {
Some(Request::read_payload(
if content_length > 0 {
self.payload = Some(Request::read_payload(
stream,
r.remaining(),
content_length.min(PAYLOAD_LIMIT),
).await)
} else {None};
content_length,
).await);
}

Some({
self.method = method;
self.path = path;
if let Some(query) = query {
self.query = query
};
self.headers = headers;
self.payload = payload;
})
Some(Ok(()))
}

#[cfg(any(feature="rt_tokio", feature="rt_async-std"))]
Expand Down Expand Up @@ -264,7 +267,7 @@ impl Request {
#[cfg(feature="testing")]
pub(crate) async fn read(mut self: Pin<&mut Self>,
raw_bytes: &mut &[u8]
) -> Option<()> {
) -> Option<Result<(), crate::Response>> {
let mut r = Reader::new(raw_bytes);

self.method = Method::from_bytes(r.read_while(|b| b != &b' '))?;
Expand All @@ -277,7 +280,7 @@ impl Request {
});
// SAFETY: Just calling for request bytes and `self.__url__` is already initialized
unsafe {let __url__ = self.__url__.assume_init_ref();
let path = Path::from_request_bytes(__url__.path().as_bytes());
let path = Path::from_request_bytes(__url__.path().as_bytes()).unwrap();
let query = __url__.query().map(|str| QueryParams::new(str.as_bytes()));
self.path = path;
if let Some(query) = query {
Expand All @@ -304,16 +307,16 @@ impl Request {
}

self.payload = {
let content_length = self.headers.ContentLength()
.unwrap_or("")
.as_bytes().into_iter()
.fold(0, |len, b| 10*len + (*b - b'0') as usize);
let content_length = match self.headers.get_raw(RequestHeader::ContentLength) {
Some(v) => unsafe {v.as_bytes()}.into_iter().fold(0, |len, b| 10*len + (*b - b'0') as usize),
None => 0,
};
(content_length > 0).then_some(CowSlice::Own(
r.remaining().into()
))
};

Some(())
Some(Ok(()))
}

#[cfg(feature="rt_worker")]
Expand All @@ -335,7 +338,7 @@ impl Request {

// SAFETY: Just calling for request bytes and `self.__url__` is already initialized
unsafe {let __url__ = self.__url__.assume_init_ref();
let path = Path::from_request_bytes(__url__.path().as_bytes());
let path = Path::from_request_bytes(__url__.path().as_bytes()).unwrap();
let query = __url__.query().map(|str| QueryParams::new(str.as_bytes()));
self.path = path;
if let Some(query) = query {
Expand Down
8 changes: 4 additions & 4 deletions ohkami/src/request/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Path {
}
}

#[inline] pub(crate) unsafe fn from_request_bytes(bytes: &[u8]) -> Self {
#[inline] pub(crate) fn from_request_bytes(bytes: &[u8]) -> Result<Self, crate::Response> {
#[cfg(debug_assertions)]
debug_assert! {
bytes.starts_with(b"/")
Expand All @@ -34,12 +34,12 @@ impl Path {
returns `b"/"` if that bytes is `b"/"`.
*/
let mut len = bytes.len();
if *bytes.get_unchecked(len-1) == b'/' {len -= 1};
if *unsafe {bytes.get_unchecked(len-1)} == b'/' {len -= 1};

Self {
Ok(Self {
raw: Slice::new_unchecked(bytes.as_ptr(), len),
params: List::new(),
}
})
}

#[inline] pub(crate) fn push_param(&mut self, param: Slice) {
Expand Down
32 changes: 18 additions & 14 deletions ohkami/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ use crate::{Request, Response};


pub(crate) struct Session {
router: Arc<RadixRouter>,
connection: TcpStream,
router: Arc<RadixRouter>,
connection: TcpStream,
}
impl Session {
pub(crate) fn new(
router: Arc<RadixRouter>,
connection: TcpStream,
router: Arc<RadixRouter>,
connection: TcpStream,
) -> Self {
Self {
router,
Expand All @@ -40,17 +40,21 @@ impl Session {
loop {
let mut req = Request::init();
let mut req = unsafe {Pin::new_unchecked(&mut req)};
if req.as_mut().read(connection).await.is_none() {break}

let close = req.headers.Connection().is_some_and(|c| c == "close");

let res = match catch_unwind(AssertUnwindSafe(|| self.router.handle(req.get_mut()))) {
Ok(future) => future.await,
Err(panic) => panicking(panic),
match req.as_mut().read(connection).await {
Some(Ok(())) => {
let close = req.headers.Connection() == Some("close");
let res = match catch_unwind(AssertUnwindSafe(|| self.router.handle(req.get_mut()))) {
Ok(future) => future.await,
Err(panic) => panicking(panic),
};
res.send(connection).await;
if close {break}
}
Some(Err(res)) => {
res.send(connection).await
}
None => break
};
res.send(connection).await;

if close {break}
}
}
}
8 changes: 5 additions & 3 deletions ohkami/src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ impl TestingOhkami {
let res = async move {
let mut request = Request::init();
let mut request = unsafe {Pin::new_unchecked(&mut request)};
request.as_mut().read(&mut &req.encode()[..]).await;

let res = router.handle(&mut request).await;

let res = match request.as_mut().read(&mut &req.encode()[..]).await.unwrap() {
Ok(()) => router.handle(&mut request).await,
Err(res) => res,
};

TestResponse::new(res)
};
Expand Down

0 comments on commit 1af1e1c

Please sign in to comment.