Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for function middlewares by dropping Debug bound and add example #545

Merged
merged 3 commits into from May 28, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
101 changes: 101 additions & 0 deletions examples/middleware.rs
@@ -0,0 +1,101 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tide::{Middleware, Next, Request, Response, Result, StatusCode};

#[derive(Debug)]
struct User {
name: String,
}

#[derive(Default)]
struct UserDatabase;
impl UserDatabase {
async fn find_user(&self) -> Option<User> {
Some(User {
name: "nori".into(),
})
}
}

// This is an example of a function middleware that uses the
// application state. Because it depends on a specific request state,
// it would likely be closely tied to a specific application
fn user_loader<'a>(
mut request: Request<UserDatabase>,
next: Next<'a, UserDatabase>,
) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> {
Box::pin(async {
if let Some(user) = request.state().find_user().await {
tide::log::trace!("user loaded", {user: user.name});
request.set_ext(user);
next.run(request).await
// this middleware only needs to run before the endpoint, so
// it just passes through the result of Next
} else {
// do not run endpoints, we could not find a user
Ok(Response::new(StatusCode::Unauthorized))
}
})
}

//
//
// this is an example of middleware that keeps its own state and could
// be provided as a third party crate
#[derive(Default)]
struct RequestCounterMiddleware {
requests_counted: Arc<AtomicUsize>,
}

impl RequestCounterMiddleware {
fn new(start: usize) -> Self {
Self {
requests_counted: Arc::new(AtomicUsize::new(start)),
}
}
}

struct RequestCount(usize);

impl<State: Send + Sync + 'static> Middleware<State> for RequestCounterMiddleware {
fn handle<'a>(
&'a self,
mut req: Request<State>,
next: Next<'a, State>,
) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> {
Box::pin(async move {
let count = self.requests_counted.fetch_add(1, Ordering::Relaxed);
tide::log::trace!("request counter", { count: count });
req.set_ext(RequestCount(count));

let mut response = next.run(req).await?;

response = response.set_header("request-number", count.to_string());
Ok(response)
})
}
}

#[async_std::main]
async fn main() -> Result<()> {
tide::log::start();
let mut app = tide::with_state(UserDatabase::default());

app.middleware(user_loader);
app.middleware(RequestCounterMiddleware::new(0));

app.at("/").get(|req: Request<_>| async move {
let count: &RequestCount = req.ext().unwrap();
let user: &User = req.ext().unwrap();

Ok(format!(
"Hello {}, this was request number {}!",
user.name, count.0
))
});

app.listen("127.0.0.1:8080").await?;
Ok(())
}
6 changes: 5 additions & 1 deletion src/middleware.rs
Expand Up @@ -17,9 +17,13 @@ pub trait Middleware<State>: 'static + Send + Sync {
/// Asynchronously handle the request, and return a response.
fn handle<'a>(
&'a self,
cx: Request<State>,
request: Request<State>,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result>;

fn name(&self) -> &str {
yoshuawuyts marked this conversation as resolved.
Show resolved Hide resolved
std::any::type_name::<Self>()
}
}

impl<State, F> Middleware<State> for F
Expand Down
6 changes: 2 additions & 4 deletions src/server.rs
Expand Up @@ -6,8 +6,6 @@ use async_std::prelude::*;
use async_std::sync::Arc;
use async_std::task;

use std::fmt::Debug;

use crate::cookies;
use crate::log;
use crate::middleware::{Middleware, Next};
Expand Down Expand Up @@ -269,9 +267,9 @@ impl<State: Send + Sync + 'static> Server<State> {
/// and is processed in the order in which it is applied.
pub fn middleware<M>(&mut self, middleware: M) -> &mut Self
where
M: Middleware<State> + Debug,
M: Middleware<State>,
{
log::trace!("Adding middleware {:?}", middleware);
log::trace!("Adding middleware {}", middleware.name());
let m = Arc::get_mut(&mut self.middleware)
.expect("Registering middleware is not possible after the Server has started");
m.push(Arc::new(middleware));
Expand Down