Skip to content

Commit

Permalink
BeforeRequest hook chaining.
Browse files Browse the repository at this point in the history
It's unintuitive that serve.before(hook1).before(hook2) executes in
reverse order, with hook2 going before hook1. With BeforeRequestList,
users can write `before().then(hook1).then(hook2).serving(serve)`, and
it will run hook1, then hook2, then the service fn.
  • Loading branch information
tikue committed Dec 30, 2023
1 parent 2c241cc commit a6758fd
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 48 deletions.
29 changes: 16 additions & 13 deletions tarpc/examples/tracing.rs
Expand Up @@ -25,7 +25,8 @@ use tarpc::{
context, serde_transport,
server::{
incoming::{spawn_incoming, Incoming},
BaseChannel, Serve,
request_hook::{self, BeforeRequestList},
BaseChannel,
},
tokio_serde::formats::Json,
ClientMessage, Response, ServerError, Transport,
Expand Down Expand Up @@ -141,19 +142,21 @@ async fn main() -> anyhow::Result<()> {
let (add_listener1, addr1) = listen_on_random_port().await?;
let (add_listener2, addr2) = listen_on_random_port().await?;
let something_bad_happened = Arc::new(AtomicBool::new(false));
let server = AddServer.serve().before(move |_: &mut _, _: &_| {
let something_bad_happened = something_bad_happened.clone();
async move {
if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
Err(ServerError::new(
io::ErrorKind::NotFound,
"Gamma Ray!".into(),
))
} else {
Ok(())
let server = request_hook::before()
.then_fn(move |_: &mut _, _: &_| {
let something_bad_happened = something_bad_happened.clone();
async move {
if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
Err(ServerError::new(
io::ErrorKind::NotFound,
"Gamma Ray!".into(),
))
} else {
Ok(())
}
}
}
});
})
.serving(AddServer.serve());
let add_server = add_listener1
.chain(add_listener2)
.map(BaseChannel::with_defaults);
Expand Down
14 changes: 7 additions & 7 deletions tarpc/src/server.rs
Expand Up @@ -36,7 +36,7 @@ pub mod limits;
pub mod incoming;

use request_hook::{
AfterRequest, AfterRequestHook, BeforeAndAfterRequestHook, BeforeRequest, BeforeRequestHook,
AfterRequest, BeforeRequest, HookThenServe, HookThenServeThenHook, ServeThenHook,
};

/// Settings that control the behavior of [channels](Channel).
Expand Down Expand Up @@ -116,12 +116,12 @@ pub trait Serve {
/// let response = serve.serve(context::current(), 1);
/// assert!(block_on(response).is_err());
/// ```
fn before<Hook>(self, hook: Hook) -> BeforeRequestHook<Self, Hook>
fn before<Hook>(self, hook: Hook) -> HookThenServe<Self, Hook>
where
Hook: BeforeRequest<Self::Req>,
Self: Sized,
{
BeforeRequestHook::new(self, hook)
HookThenServe::new(self, hook)
}

/// Runs a hook after completion of a request.
Expand Down Expand Up @@ -159,12 +159,12 @@ pub trait Serve {
/// let response = serve.serve(context::current(), 1);
/// assert!(block_on(response).is_err());
/// ```
fn after<Hook>(self, hook: Hook) -> AfterRequestHook<Self, Hook>
fn after<Hook>(self, hook: Hook) -> ServeThenHook<Self, Hook>
where
Hook: AfterRequest<Self::Resp>,
Self: Sized,
{
AfterRequestHook::new(self, hook)
ServeThenHook::new(self, hook)
}

/// Runs a hook before and after execution of the request.
Expand Down Expand Up @@ -212,12 +212,12 @@ pub trait Serve {
fn before_and_after<Hook>(
self,
hook: Hook,
) -> BeforeAndAfterRequestHook<Self::Req, Self::Resp, Self, Hook>
) -> HookThenServeThenHook<Self::Req, Self::Resp, Self, Hook>
where
Hook: BeforeRequest<Self::Req> + AfterRequest<Self::Resp>,
Self: Sized,
{
BeforeAndAfterRequestHook::new(self, hook)
HookThenServeThenHook::new(self, hook)
}
}

Expand Down
9 changes: 6 additions & 3 deletions tarpc/src/server/request_hook.rs
Expand Up @@ -16,7 +16,10 @@ mod after;
mod before_and_after;

pub use {
after::{AfterRequest, AfterRequestHook},
before::{BeforeRequest, BeforeRequestHook},
before_and_after::BeforeAndAfterRequestHook,
after::{AfterRequest, ServeThenHook},
before::{
before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil,
HookThenServe,
},
before_and_after::HookThenServeThenHook,
};
10 changes: 5 additions & 5 deletions tarpc/src/server/request_hook/after.rs
Expand Up @@ -29,18 +29,18 @@ where
}

/// A Service function that runs a hook after request execution.
pub struct AfterRequestHook<Serv, Hook> {
pub struct ServeThenHook<Serv, Hook> {
serve: Serv,
hook: Hook,
}

impl<Serv, Hook> AfterRequestHook<Serv, Hook> {
impl<Serv, Hook> ServeThenHook<Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self { serve, hook }
}
}

impl<Serv: Clone, Hook: Clone> Clone for AfterRequestHook<Serv, Hook> {
impl<Serv: Clone, Hook: Clone> Clone for ServeThenHook<Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
Expand All @@ -49,7 +49,7 @@ impl<Serv: Clone, Hook: Clone> Clone for AfterRequestHook<Serv, Hook> {
}
}

impl<Serv, Hook> Serve for AfterRequestHook<Serv, Hook>
impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
where
Serv: Serve,
Hook: AfterRequest<Serv::Resp>,
Expand All @@ -62,7 +62,7 @@ where
mut ctx: context::Context,
req: Serv::Req,
) -> Result<Serv::Resp, ServerError> {
let AfterRequestHook {
let ServeThenHook {
serve, mut hook, ..
} = self;
let mut resp = serve.serve(ctx, req).await;
Expand Down
161 changes: 148 additions & 13 deletions tarpc/src/server/request_hook/before.rs
Expand Up @@ -22,6 +22,38 @@ pub trait BeforeRequest<Req> {
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>;
}

/// A list of hooks that run in order before request execution.
pub trait BeforeRequestList<Req>: BeforeRequest<Req> {
/// The hook returned by `BeforeRequestList::then`.
type Then<Next>: BeforeRequest<Req>
where
Next: BeforeRequest<Req>;

/// Returns a hook that, when run, runs two hooks, first `self` and then `next`.
fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next>;

/// Same as `then`, but helps the compiler with type inference when Next is a closure.
fn then_fn<
Next: FnMut(&mut context::Context, &Req) -> Fut,
Fut: Future<Output = Result<(), ServerError>>,
>(
self,
next: Next,
) -> Self::Then<Next>
where
Self: Sized,
{
self.then(next)
}

/// The service fn returned by `BeforeRequestList::serving`.
type Serve<S: Serve<Req = Req>>: Serve<Req = Req>;

/// Runs the list of request hooks before execution of the given serve fn.
/// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer.
fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S>;
}

impl<F, Fut, Req> BeforeRequest<Req> for F
where
F: FnMut(&mut context::Context, &Req) -> Fut,
Expand All @@ -33,27 +65,19 @@ where
}

/// A Service function that runs a hook before request execution.
pub struct BeforeRequestHook<Serv, Hook> {
#[derive(Clone)]
pub struct HookThenServe<Serv, Hook> {
serve: Serv,
hook: Hook,
}

impl<Serv, Hook> BeforeRequestHook<Serv, Hook> {
impl<Serv, Hook> HookThenServe<Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self { serve, hook }
}
}

impl<Serv: Clone, Hook: Clone> Clone for BeforeRequestHook<Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
hook: self.hook.clone(),
}
}
}

impl<Serv, Hook> Serve for BeforeRequestHook<Serv, Hook>
impl<Serv, Hook> Serve for HookThenServe<Serv, Hook>
where
Serv: Serve,
Hook: BeforeRequest<Serv::Req>,
Expand All @@ -66,10 +90,121 @@ where
mut ctx: context::Context,
req: Self::Req,
) -> Result<Serv::Resp, ServerError> {
let BeforeRequestHook {
let HookThenServe {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
serve.serve(ctx, req).await
}
}

/// Returns a request hook builder that runs a series of hooks before request execution.
///
/// Example
///
/// ```rust
/// use futures::{executor::block_on, future};
/// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self,
/// BeforeRequest, BeforeRequestList}}};
/// use std::{cell::Cell, io};
///
/// let i = Cell::new(0);
/// let serve = request_hook::before()
/// .then_fn(|_, _| async {
/// assert!(i.get() == 0);
/// i.set(1);
/// Ok(())
/// })
/// .then_fn(|_, _| async {
/// assert!(i.get() == 1);
/// i.set(2);
/// Ok(())
/// })
/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }));
/// let response = serve.clone().serve(context::current(), 1);
/// assert!(block_on(response).is_ok());
/// assert!(i.get() == 2);
/// ```
pub fn before() -> BeforeRequestNil {
BeforeRequestNil
}

/// A list of hooks that run in order before a request is executed.
#[derive(Clone, Copy)]
pub struct BeforeRequestCons<First, Rest>(First, Rest);

/// A noop hook that runs before a request is executed.
#[derive(Clone, Copy)]
pub struct BeforeRequestNil;

impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequest<Req>> BeforeRequest<Req>
for BeforeRequestCons<First, Rest>
{
async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> {
let BeforeRequestCons(first, rest) = self;
first.before(ctx, req).await?;
rest.before(ctx, req).await?;
Ok(())
}
}

impl<Req> BeforeRequest<Req> for BeforeRequestNil {
async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> {
Ok(())
}
}

impl<Req, First: BeforeRequest<Req>, Rest: BeforeRequestList<Req>> BeforeRequestList<Req>
for BeforeRequestCons<First, Rest>
{
type Then<Next> = BeforeRequestCons<First, Rest::Then<Next>> where Next: BeforeRequest<Req>;

fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
let BeforeRequestCons(first, rest) = self;
BeforeRequestCons(first, rest.then(next))
}

type Serve<S: Serve<Req = Req>> = HookThenServe<S, Self>;

fn serving<S: Serve<Req = Req>>(self, serve: S) -> Self::Serve<S> {
HookThenServe::new(serve, self)
}
}

impl<Req> BeforeRequestList<Req> for BeforeRequestNil {
type Then<Next> = BeforeRequestCons<Next, BeforeRequestNil> where Next: BeforeRequest<Req>;

fn then<Next: BeforeRequest<Req>>(self, next: Next) -> Self::Then<Next> {
BeforeRequestCons(next, BeforeRequestNil)
}

type Serve<S: Serve<Req = Req>> = S;

fn serving<S: Serve<Req = Req>>(self, serve: S) -> S {
serve
}
}

#[test]
fn before_request_list() {
use crate::server::serve;
use futures::executor::block_on;
use std::cell::Cell;

let i = Cell::new(0);
let serve = before()
.then_fn(|_, _| async {
assert!(i.get() == 0);
i.set(1);
Ok(())
})
.then_fn(|_, _| async {
assert!(i.get() == 1);
i.set(2);
Ok(())
})
.serving(serve(|_ctx, i| async move { Ok(i + 1) }));
let response = serve.clone().serve(context::current(), 1);
assert!(block_on(response).is_ok());
assert!(i.get() == 2);
}
12 changes: 5 additions & 7 deletions tarpc/src/server/request_hook/before_and_after.rs
Expand Up @@ -11,13 +11,13 @@ use crate::{context, server::Serve, ServerError};
use std::marker::PhantomData;

/// A Service function that runs a hook both before and after request execution.
pub struct BeforeAndAfterRequestHook<Req, Resp, Serv, Hook> {
pub struct HookThenServeThenHook<Req, Resp, Serv, Hook> {
serve: Serv,
hook: Hook,
fns: PhantomData<(fn(Req), fn(Resp))>,
}

impl<Req, Resp, Serv, Hook> BeforeAndAfterRequestHook<Req, Resp, Serv, Hook> {
impl<Req, Resp, Serv, Hook> HookThenServeThenHook<Req, Resp, Serv, Hook> {
pub(crate) fn new(serve: Serv, hook: Hook) -> Self {
Self {
serve,
Expand All @@ -27,9 +27,7 @@ impl<Req, Resp, Serv, Hook> BeforeAndAfterRequestHook<Req, Resp, Serv, Hook> {
}
}

impl<Req, Resp, Serv: Clone, Hook: Clone> Clone
for BeforeAndAfterRequestHook<Req, Resp, Serv, Hook>
{
impl<Req, Resp, Serv: Clone, Hook: Clone> Clone for HookThenServeThenHook<Req, Resp, Serv, Hook> {
fn clone(&self) -> Self {
Self {
serve: self.serve.clone(),
Expand All @@ -39,7 +37,7 @@ impl<Req, Resp, Serv: Clone, Hook: Clone> Clone
}
}

impl<Req, Resp, Serv, Hook> Serve for BeforeAndAfterRequestHook<Req, Resp, Serv, Hook>
impl<Req, Resp, Serv, Hook> Serve for HookThenServeThenHook<Req, Resp, Serv, Hook>
where
Serv: Serve<Req = Req, Resp = Resp>,
Hook: BeforeRequest<Req> + AfterRequest<Resp>,
Expand All @@ -48,7 +46,7 @@ where
type Resp = Resp;

async fn serve(self, mut ctx: context::Context, req: Req) -> Result<Serv::Resp, ServerError> {
let BeforeAndAfterRequestHook {
let HookThenServeThenHook {
serve, mut hook, ..
} = self;
hook.before(&mut ctx, &req).await?;
Expand Down

0 comments on commit a6758fd

Please sign in to comment.