From a6758fd1f9555800fdc65e60add7a5336ec2c7e0 Mon Sep 17 00:00:00 2001 From: Tim Kuehn Date: Fri, 29 Dec 2023 00:11:41 -0800 Subject: [PATCH] BeforeRequest hook chaining. It's unintuitive that serve.before(hook1).before(hook2) executes in reverse order, with hook2 going before hook1. With BeforeRequestList, users can write `before().then(hook1).then(hook2).serving(serve)`, and it will run hook1, then hook2, then the service fn. --- tarpc/examples/tracing.rs | 29 ++-- tarpc/src/server.rs | 14 +- tarpc/src/server/request_hook.rs | 9 +- tarpc/src/server/request_hook/after.rs | 10 +- tarpc/src/server/request_hook/before.rs | 161 ++++++++++++++++-- .../server/request_hook/before_and_after.rs | 12 +- 6 files changed, 187 insertions(+), 48 deletions(-) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 8cb233cd..0ccc2b37 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -25,7 +25,8 @@ use tarpc::{ context, serde_transport, server::{ incoming::{spawn_incoming, Incoming}, - BaseChannel, Serve, + request_hook::{self, BeforeRequestList}, + BaseChannel, }, tokio_serde::formats::Json, ClientMessage, Response, ServerError, Transport, @@ -141,19 +142,21 @@ async fn main() -> anyhow::Result<()> { let (add_listener1, addr1) = listen_on_random_port().await?; let (add_listener2, addr2) = listen_on_random_port().await?; let something_bad_happened = Arc::new(AtomicBool::new(false)); - let server = AddServer.serve().before(move |_: &mut _, _: &_| { - let something_bad_happened = something_bad_happened.clone(); - async move { - if something_bad_happened.fetch_xor(true, Ordering::Relaxed) { - Err(ServerError::new( - io::ErrorKind::NotFound, - "Gamma Ray!".into(), - )) - } else { - Ok(()) + let server = request_hook::before() + .then_fn(move |_: &mut _, _: &_| { + let something_bad_happened = something_bad_happened.clone(); + async move { + if something_bad_happened.fetch_xor(true, Ordering::Relaxed) { + Err(ServerError::new( + io::ErrorKind::NotFound, + "Gamma Ray!".into(), + )) + } else { + Ok(()) + } } - } - }); + }) + .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) .map(BaseChannel::with_defaults); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index aafa8766..68dd3b06 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -36,7 +36,7 @@ pub mod limits; pub mod incoming; use request_hook::{ - AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook, + AfterRequest, BeforeRequest, HookThenServe, HookThenServeThenHook, ServeThenHook, }; /// Settings that control the behavior of [channels](Channel). @@ -116,12 +116,12 @@ pub trait Serve { /// let response = serve.serve(context::current(), 1); /// assert!(block_on(response).is_err()); /// ``` - fn before(self, hook: Hook) -> BeforeRequestHook + fn before(self, hook: Hook) -> HookThenServe where Hook: BeforeRequest, Self: Sized, { - BeforeRequestHook::new(self, hook) + HookThenServe::new(self, hook) } /// Runs a hook after completion of a request. @@ -159,12 +159,12 @@ pub trait Serve { /// let response = serve.serve(context::current(), 1); /// assert!(block_on(response).is_err()); /// ``` - fn after(self, hook: Hook) -> AfterRequestHook + fn after(self, hook: Hook) -> ServeThenHook where Hook: AfterRequest, Self: Sized, { - AfterRequestHook::new(self, hook) + ServeThenHook::new(self, hook) } /// Runs a hook before and after execution of the request. @@ -212,12 +212,12 @@ pub trait Serve { fn before_and_after( self, hook: Hook, - ) -> BeforeAndAfterRequestHook + ) -> HookThenServeThenHook where Hook: BeforeRequest + AfterRequest, Self: Sized, { - BeforeAndAfterRequestHook::new(self, hook) + HookThenServeThenHook::new(self, hook) } } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index ef23d73b..524134ad 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -16,7 +16,10 @@ mod after; mod before_and_after; pub use { - after::{AfterRequest, AfterRequestHook}, - before::{BeforeRequest, BeforeRequestHook}, - before_and_after::BeforeAndAfterRequestHook, + after::{AfterRequest, ServeThenHook}, + before::{ + before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil, + HookThenServe, + }, + before_and_after::HookThenServeThenHook, }; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index 7f6f1c56..59afb473 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -29,18 +29,18 @@ where } /// A Service function that runs a hook after request execution. -pub struct AfterRequestHook { +pub struct ServeThenHook { serve: Serv, hook: Hook, } -impl AfterRequestHook { +impl ServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, hook } } } -impl Clone for AfterRequestHook { +impl Clone for ServeThenHook { fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -49,7 +49,7 @@ impl Clone for AfterRequestHook { } } -impl Serve for AfterRequestHook +impl Serve for ServeThenHook where Serv: Serve, Hook: AfterRequest, @@ -62,7 +62,7 @@ where mut ctx: context::Context, req: Serv::Req, ) -> Result { - let AfterRequestHook { + let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index f7d56aad..a221219e 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -22,6 +22,38 @@ pub trait BeforeRequest { async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; } +/// A list of hooks that run in order before request execution. +pub trait BeforeRequestList: BeforeRequest { + /// The hook returned by `BeforeRequestList::then`. + type Then: BeforeRequest + where + Next: BeforeRequest; + + /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. + fn then>(self, next: Next) -> Self::Then; + + /// Same as `then`, but helps the compiler with type inference when Next is a closure. + fn then_fn< + Next: FnMut(&mut context::Context, &Req) -> Fut, + Fut: Future>, + >( + self, + next: Next, + ) -> Self::Then + where + Self: Sized, + { + self.then(next) + } + + /// The service fn returned by `BeforeRequestList::serving`. + type Serve>: Serve; + + /// Runs the list of request hooks before execution of the given serve fn. + /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. + fn serving>(self, serve: S) -> Self::Serve; +} + impl BeforeRequest for F where F: FnMut(&mut context::Context, &Req) -> Fut, @@ -33,27 +65,19 @@ where } /// A Service function that runs a hook before request execution. -pub struct BeforeRequestHook { +#[derive(Clone)] +pub struct HookThenServe { serve: Serv, hook: Hook, } -impl BeforeRequestHook { +impl HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, hook } } } -impl Clone for BeforeRequestHook { - fn clone(&self) -> Self { - Self { - serve: self.serve.clone(), - hook: self.hook.clone(), - } - } -} - -impl Serve for BeforeRequestHook +impl Serve for HookThenServe where Serv: Serve, Hook: BeforeRequest, @@ -66,10 +90,121 @@ where mut ctx: context::Context, req: Self::Req, ) -> Result { - let BeforeRequestHook { + let HookThenServe { serve, mut hook, .. } = self; hook.before(&mut ctx, &req).await?; serve.serve(ctx, req).await } } + +/// Returns a request hook builder that runs a series of hooks before request execution. +/// +/// Example +/// +/// ```rust +/// use futures::{executor::block_on, future}; +/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, +/// BeforeRequest, BeforeRequestList}}}; +/// use std::{cell::Cell, io}; +/// +/// let i = Cell::new(0); +/// let serve = request_hook::before() +/// .then_fn(|_, _| async { +/// assert!(i.get() == 0); +/// i.set(1); +/// Ok(()) +/// }) +/// .then_fn(|_, _| async { +/// assert!(i.get() == 1); +/// i.set(2); +/// Ok(()) +/// }) +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); +/// let response = serve.clone().serve(context::current(), 1); +/// assert!(block_on(response).is_ok()); +/// assert!(i.get() == 2); +/// ``` +pub fn before() -> BeforeRequestNil { + BeforeRequestNil +} + +/// A list of hooks that run in order before a request is executed. +#[derive(Clone, Copy)] +pub struct BeforeRequestCons(First, Rest); + +/// A noop hook that runs before a request is executed. +#[derive(Clone, Copy)] +pub struct BeforeRequestNil; + +impl, Rest: BeforeRequest> BeforeRequest + for BeforeRequestCons +{ + async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + let BeforeRequestCons(first, rest) = self; + first.before(ctx, req).await?; + rest.before(ctx, req).await?; + Ok(()) + } +} + +impl BeforeRequest for BeforeRequestNil { + async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + Ok(()) + } +} + +impl, Rest: BeforeRequestList> BeforeRequestList + for BeforeRequestCons +{ + type Then = BeforeRequestCons> where Next: BeforeRequest; + + fn then>(self, next: Next) -> Self::Then { + let BeforeRequestCons(first, rest) = self; + BeforeRequestCons(first, rest.then(next)) + } + + type Serve> = HookThenServe; + + fn serving>(self, serve: S) -> Self::Serve { + HookThenServe::new(serve, self) + } +} + +impl BeforeRequestList for BeforeRequestNil { + type Then = BeforeRequestCons where Next: BeforeRequest; + + fn then>(self, next: Next) -> Self::Then { + BeforeRequestCons(next, BeforeRequestNil) + } + + type Serve> = S; + + fn serving>(self, serve: S) -> S { + serve + } +} + +#[test] +fn before_request_list() { + use crate::server::serve; + use futures::executor::block_on; + use std::cell::Cell; + + let i = Cell::new(0); + let serve = before() + .then_fn(|_, _| async { + assert!(i.get() == 0); + i.set(1); + Ok(()) + }) + .then_fn(|_, _| async { + assert!(i.get() == 1); + i.set(2); + Ok(()) + }) + .serving(serve(|_ctx, i| async move { Ok(i + 1) })); + let response = serve.clone().serve(context::current(), 1); + assert!(block_on(response).is_ok()); + assert!(i.get() == 2); +} diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index ff61a53e..995ddea8 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -11,13 +11,13 @@ use crate::{context, server::Serve, ServerError}; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. -pub struct BeforeAndAfterRequestHook { +pub struct HookThenServeThenHook { serve: Serv, hook: Hook, fns: PhantomData<(fn(Req), fn(Resp))>, } -impl BeforeAndAfterRequestHook { +impl HookThenServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, @@ -27,9 +27,7 @@ impl BeforeAndAfterRequestHook { } } -impl Clone - for BeforeAndAfterRequestHook -{ +impl Clone for HookThenServeThenHook { fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -39,7 +37,7 @@ impl Clone } } -impl Serve for BeforeAndAfterRequestHook +impl Serve for HookThenServeThenHook where Serv: Serve, Hook: BeforeRequest + AfterRequest, @@ -48,7 +46,7 @@ where type Resp = Resp; async fn serve(self, mut ctx: context::Context, req: Req) -> Result { - let BeforeAndAfterRequestHook { + let HookThenServeThenHook { serve, mut hook, .. } = self; hook.before(&mut ctx, &req).await?;