diff --git a/Cargo.lock b/Cargo.lock index 77b2abf8..1d5aa479 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2997,7 +2997,7 @@ dependencies = [ [[package]] name = "volo-http" -version = "0.1.8" +version = "0.1.9" dependencies = [ "bytes", "cookie", diff --git a/examples/src/http/simple.rs b/examples/src/http/simple.rs index 4e938298..f5c05058 100644 --- a/examples/src/http/simple.rs +++ b/examples/src/http/simple.rs @@ -7,7 +7,7 @@ use volo_http::{ extension::Extension, extract::{Form, Query}, http::header, - layer::TimeoutLayer, + layer::{FilterLayer, TimeoutLayer}, middleware::{self, Next}, response::IntoResponse, route::{from_handler, get, post, service_fn, MethodRouter, Router}, @@ -172,6 +172,14 @@ fn test_router() -> Router { .get(service_fn(service_fn_test)) .build(), ) + // curl -v http://127.0.0.1:8080/test/anyaddr?reject_me + .layer(FilterLayer::new(|uri: Uri| async move { + if uri.query().is_some() && uri.query().unwrap() == "reject_me" { + Err(StatusCode::INTERNAL_SERVER_ERROR) + } else { + Ok(()) + } + })) } // You can use the following commands for testing cookies diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 2740e18d..8450ba3b 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "volo-http" -version = "0.1.8" +version = "0.1.9" edition.workspace = true homepage.workspace = true repository.workspace = true diff --git a/volo-http/src/handler.rs b/volo-http/src/handler.rs index ad74c397..d18e5c5c 100644 --- a/volo-http/src/handler.rs +++ b/volo-http/src/handler.rs @@ -251,18 +251,17 @@ where } } -pub trait HandlerWithoutRequest: Sized { - fn call(self, cx: &mut HttpContext) -> impl Future + Send; +pub trait HandlerWithoutRequest: Sized { + fn call(self, cx: &mut HttpContext) -> impl Future> + Send; } -impl HandlerWithoutRequest<()> for F +impl HandlerWithoutRequest<(), Ret> for F where F: FnOnce() -> Fut + Clone + Send, - Fut: Future + Send, - Res: IntoResponse, + Fut: Future + Send, { - async fn call(self, _context: &mut HttpContext) -> Response { - self().await.into_response() + async fn call(self, _context: &mut HttpContext) -> Result { + Ok(self().await) } } @@ -271,21 +270,20 @@ macro_rules! impl_handler_without_request { $($ty:ident),* $(,)? ) => { #[allow(non_snake_case, unused_mut, unused_variables)] - impl HandlerWithoutRequest<($($ty,)*)> for F + impl HandlerWithoutRequest<($($ty,)*), Ret> for F where F: FnOnce($($ty,)*) -> Fut + Clone + Send, - Fut: Future + Send, - Res: IntoResponse, + Fut: Future + Send, $( for<'r> $ty: FromContext<()> + Send + 'r, )* { - async fn call(self, cx: &mut HttpContext) -> Response { + async fn call(self, cx: &mut HttpContext) -> Result { $( let $ty = match $ty::from_context(cx, &()).await { Ok(value) => value, - Err(rejection) => return rejection.into_response(), + Err(rejection) => return Err(rejection.into_response()), }; )* - self($($ty,)*).await.into_response() + Ok(self($($ty,)*).await) } } }; diff --git a/volo-http/src/layer.rs b/volo-http/src/layer.rs index 417f2989..56a9fab5 100644 --- a/volo-http/src/layer.rs +++ b/volo-http/src/layer.rs @@ -1,72 +1,62 @@ -use std::{marker::PhantomData, time::Duration}; +use std::{convert::Infallible, marker::PhantomData, time::Duration}; -use hyper::{ - body::Incoming, - http::{Method, StatusCode}, -}; +use hyper::body::Incoming; use motore::{layer::Layer, service::Service}; use crate::{ handler::HandlerWithoutRequest, - request::Request, response::{IntoResponse, Response}, HttpContext, }; -pub trait LayerExt { - fn method( - self, - method: Method, - ) -> FilterLayer Result<(), StatusCode>>> - where - Self: Sized, - { - self.filter(Box::new(move |cx: &mut HttpContext, _: &Request| { - if cx.method == method { - Ok(()) - } else { - Err(StatusCode::METHOD_NOT_ALLOWED) - } - })) - } - - fn filter(self, f: F) -> FilterLayer - where - Self: Sized, - F: Fn(&mut HttpContext, &Request) -> Result<(), StatusCode>, - { - FilterLayer { f } - } +#[derive(Clone)] +pub struct FilterLayer { + handler: H, + _marker: PhantomData<(R, T)>, } -pub struct FilterLayer { - f: F, +impl FilterLayer { + pub fn new(h: H) -> Self { + Self { + handler: h, + _marker: PhantomData, + } + } } -impl Layer for FilterLayer +impl Layer for FilterLayer where - S: Service + Send + Sync + 'static, - F: Fn(&mut HttpContext, &Request) -> Result<(), StatusCode> + Send + Sync, + S: Send + Sync + 'static, + H: Clone + Send + Sync + 'static, + T: Sync, { - type Service = Filter; + type Service = Filter; fn layer(self, inner: S) -> Self::Service { Filter { service: inner, - f: self.f, + handler: self.handler, + _marker: PhantomData, } } } -pub struct Filter { +#[derive(Clone)] +pub struct Filter { service: S, - f: F, + handler: H, + _marker: PhantomData<(R, T)>, } -impl Service for Filter +impl Service for Filter where - S: Service + Send + Sync + 'static, - F: Fn(&mut HttpContext, &Request) -> Result<(), StatusCode> + Send + Sync, + S: Service + + Send + + Sync + + 'static, + H: HandlerWithoutRequest> + Clone + Send + Sync + 'static, + R: IntoResponse + Send + Sync, + T: Sync, { type Response = S::Response; @@ -75,26 +65,33 @@ where async fn call<'s, 'cx>( &'s self, cx: &'cx mut HttpContext, - req: Request, + req: Incoming, ) -> Result { - if let Err(status) = (self.f)(cx, &req) { - return Ok(status.into_response()); + match self.handler.clone().call(cx).await { + // do not filter it, call the service + Ok(Ok(())) => self.service.call(cx, req).await, + // filter it and return the specified response + Ok(Err(res)) => Ok(res.into_response()), + // something wrong while extracting + Err(rej) => { + tracing::warn!("[VOLO] FilterLayer: something wrong while extracting"); + Ok(rej.into_response()) + } } - self.service.call(cx, req).await } } #[derive(Clone)] -pub struct TimeoutLayer { +pub struct TimeoutLayer { duration: Duration, handler: H, - _marker: PhantomData, + _marker: PhantomData<(R, T)>, } -impl TimeoutLayer { +impl TimeoutLayer { pub fn new(duration: Duration, handler: H) -> Self where - H: HandlerWithoutRequest + Clone + Send + Sync + 'static, + H: Send + Sync + 'static, { Self { duration, @@ -104,13 +101,14 @@ impl TimeoutLayer { } } -impl Layer for TimeoutLayer +impl Layer for TimeoutLayer where - S: Service + Send + Sync + 'static, - H: HandlerWithoutRequest + Clone + Send + Sync + 'static, + S: Send + Sync + 'static, + H: Clone + Send + Sync + 'static, + R: Sync, T: Sync, { - type Service = Timeout; + type Service = Timeout; fn layer(self, inner: S) -> Self::Service { Timeout { @@ -123,18 +121,21 @@ where } #[derive(Clone)] -pub struct Timeout { +pub struct Timeout { service: S, duration: Duration, handler: H, - _marker: PhantomData, + _marker: PhantomData<(R, T)>, } -impl Service for Timeout +impl Service for Timeout where - S: Service + Send + Sync + 'static, - S::Error: Send, - H: HandlerWithoutRequest + Clone + Send + Sync + 'static, + S: Service + + Send + + Sync + + 'static, + H: HandlerWithoutRequest + Clone + Send + Sync + 'static, + R: IntoResponse + Sync, T: Sync, { type Response = S::Response; @@ -152,7 +153,7 @@ where tokio::select! { resp = fut_service => resp, _ = fut_timeout => { - Ok(self.handler.clone().call(cx).await) + Ok(self.handler.clone().call(cx).await.into_response()) }, } } diff --git a/volo-http/src/request.rs b/volo-http/src/request.rs index 48794c13..43d87ccd 100644 --- a/volo-http/src/request.rs +++ b/volo-http/src/request.rs @@ -1,31 +1 @@ -use std::ops::{Deref, DerefMut}; - -use hyper::{body::Incoming, http::request::Builder}; - -pub struct Request(pub(crate) hyper::http::Request); - -impl Request { - pub fn builder() -> Builder { - Builder::new() - } -} - -impl Deref for Request { - type Target = hyper::http::Request; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for Request { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl From> for Request { - fn from(value: hyper::http::Request) -> Self { - Self(value) - } -} +pub type Request = hyper::http::Request; diff --git a/volo-http/src/response.rs b/volo-http/src/response.rs index 22022079..e4bf080c 100644 --- a/volo-http/src/response.rs +++ b/volo-http/src/response.rs @@ -1,6 +1,5 @@ use std::{ convert::Infallible, - ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, }; @@ -10,42 +9,12 @@ use http_body_util::Full; use hyper::{ body::{Body, Bytes, Frame, SizeHint}, header::HeaderValue, - http::{header::IntoHeaderName, response::Builder, StatusCode}, + http::{header::IntoHeaderName, StatusCode}, HeaderMap, }; use pin_project::pin_project; -pub struct Response(hyper::http::Response); - -impl Response { - pub fn builder() -> Builder { - Builder::new() - } - - pub(crate) fn inner(self) -> hyper::http::Response { - self.0 - } -} - -impl Deref for Response { - type Target = hyper::http::Response; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for Response { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl From> for Response { - fn from(value: hyper::http::Response) -> Self { - Self(value) - } -} +pub type Response = hyper::http::Response; #[pin_project] pub struct RespBody { @@ -190,7 +159,7 @@ where { fn into_response(self) -> Response { let mut resp = self.1.into_response(); - *resp.0.status_mut() = self.0; + *resp.status_mut() = self.0; resp } } @@ -205,9 +174,13 @@ impl IntoResponse for StatusCode { } } -impl IntoResponse for Response { +impl IntoResponse for hyper::http::Response +where + B: Into, +{ fn into_response(self) -> Response { - self + let (parts, body) = self.into_parts(); + Response::from_parts(parts, body.into()) } } diff --git a/volo-http/src/server.rs b/volo-http/src/server.rs index 12693328..21e905bd 100644 --- a/volo-http/src/server.rs +++ b/volo-http/src/server.rs @@ -267,7 +267,7 @@ async fn handle_conn( Ok(resp) => resp, Err(inf) => inf.into_response(), }; - Ok::, Infallible>(resp.inner()) + Ok::, Infallible>(resp) } }), );