Skip to content

Commit

Permalink
Server: require State to be Clone
Browse files Browse the repository at this point in the history
Alternative to
#642

This approach is more flexible but requires the user ensure that their
state implements/derives `Clone`, or is wrapped in an `Arc`.

Co-authored-by: Jacob Rothstein <hi@jbr.me>
  • Loading branch information
Fishrock123 and jbr committed Jul 12, 2020
1 parent 3778706 commit da703eb
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 51 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ serde = "1.0.102"
serde_json = "1.0.41"
route-recognizer = "0.2.0"
logtest = "2.0.0"
pin-project-lite = "0.1.7"

[dev-dependencies]
async-std = { version = "1.6.0", features = ["unstable", "attributes"] }
Expand Down
10 changes: 6 additions & 4 deletions examples/graphql.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use async_std::task;
use juniper::{http::graphiql, http::GraphQLRequest, RootNode};
use std::sync::RwLock;
Expand Down Expand Up @@ -72,7 +74,7 @@ fn create_schema() -> Schema {
Schema::new(QueryRoot {}, MutationRoot {})
}

async fn handle_graphql(mut request: Request<State>) -> tide::Result {
async fn handle_graphql(mut request: Request<Arc<State>>) -> tide::Result {
let query: GraphQLRequest = request.body_json().await?;
let schema = create_schema(); // probably worth making the schema a singleton using lazy_static library
let response = query.execute(&schema, request.state());
Expand All @@ -87,17 +89,17 @@ async fn handle_graphql(mut request: Request<State>) -> tide::Result {
.build())
}

async fn handle_graphiql(_: Request<State>) -> tide::Result<impl Into<Response>> {
async fn handle_graphiql(_: Request<Arc<State>>) -> tide::Result<impl Into<Response>> {
Ok(Response::builder(200)
.body(graphiql::graphiql_source("/graphql"))
.content_type(mime::HTML))
}

fn main() -> std::io::Result<()> {
task::block_on(async {
let mut app = Server::with_state(State {
let mut app = Server::with_state(Arc::new(State {
users: RwLock::new(Vec::new()),
});
}));
app.at("/").get(Redirect::permanent("/graphiql"));
app.at("/graphql").post(handle_graphql);
app.at("/graphiql").get(handle_graphiql);
Expand Down
16 changes: 9 additions & 7 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ impl UserDatabase {
// application state. Because it depends on a specific request state,
// it would likely be closely tied to a specific application
fn user_loader<'a>(
mut request: Request<UserDatabase>,
next: Next<'a, UserDatabase>,
mut request: Request<Arc<UserDatabase>>,
next: Next<'a, Arc<UserDatabase>>,
) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> {
Box::pin(async {
if let Some(user) = request.state().find_user().await {
Expand Down Expand Up @@ -98,7 +98,7 @@ const INTERNAL_SERVER_ERROR_HTML_PAGE: &str = "<html><body>
#[async_std::main]
async fn main() -> Result<()> {
tide::log::start();
let mut app = tide::with_state(UserDatabase::default());
let mut app = tide::with_state(Arc::new(UserDatabase::default()));

app.middleware(After(|response: Response| async move {
let response = match response.status() {
Expand All @@ -120,10 +120,12 @@ async fn main() -> Result<()> {

app.middleware(user_loader);
app.middleware(RequestCounterMiddleware::new(0));
app.middleware(Before(|mut request: Request<UserDatabase>| async move {
request.set_ext(std::time::Instant::now());
request
}));
app.middleware(Before(
|mut request: Request<Arc<UserDatabase>>| async move {
request.set_ext(std::time::Instant::now());
request
},
));

app.at("/").get(|req: Request<_>| async move {
let count: &RequestCount = req.ext().unwrap();
Expand Down
8 changes: 5 additions & 3 deletions examples/upload.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use async_std::{fs::OpenOptions, io};
use tempfile::TempDir;
use tide::prelude::*;
Expand All @@ -6,15 +8,15 @@ use tide::{Body, Request, Response, StatusCode};
#[async_std::main]
async fn main() -> Result<(), std::io::Error> {
tide::log::start();
let mut app = tide::with_state(tempfile::tempdir()?);
let mut app = tide::with_state(Arc::new(tempfile::tempdir()?));

// To test this example:
// $ cargo run --example upload
// $ curl -T ./README.md locahost:8080 # this writes the file to a temp directory
// $ curl localhost:8080/README.md # this reads the file from the same temp directory

app.at(":file")
.put(|req: Request<TempDir>| async move {
.put(|req: Request<Arc<TempDir>>| async move {
let path: String = req.param("file")?;
let fs_path = req.state().path().join(path);

Expand All @@ -33,7 +35,7 @@ async fn main() -> Result<(), std::io::Error> {

Ok(json!({ "bytes": bytes_written }))
})
.get(|req: Request<TempDir>| async move {
.get(|req: Request<Arc<TempDir>>| async move {
let path: String = req.param("file")?;
let fs_path = req.state().path().join(path);

Expand Down
4 changes: 1 addition & 3 deletions src/fs/serve_dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ where
mod test {
use super::*;

use async_std::sync::Arc;

use std::fs::{self, File};
use std::io::Write;

Expand All @@ -85,7 +83,7 @@ mod test {
let request = crate::http::Request::get(
crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(),
);
crate::Request::new(Arc::new(()), request, vec![])
crate::Request::new((), request, vec![])
}

#[async_std::test]
Expand Down
7 changes: 4 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ pub fn new() -> server::Server<()> {
/// # use async_std::task::block_on;
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
/// #
/// use std::sync::Arc;
/// use tide::Request;
///
/// /// The shared application state.
Expand All @@ -268,8 +269,8 @@ pub fn new() -> server::Server<()> {
/// };
///
/// // Initialize the application with state.
/// let mut app = tide::with_state(state);
/// app.at("/").get(|req: Request<State>| async move {
/// let mut app = tide::with_state(Arc::new(state));
/// app.at("/").get(|req: Request<Arc<State>>| async move {
/// Ok(format!("Hello, {}!", &req.state().name))
/// });
/// app.listen("127.0.0.1:8080").await?;
Expand All @@ -278,7 +279,7 @@ pub fn new() -> server::Server<()> {
/// ```
pub fn with_state<State>(state: State) -> server::Server<State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
Server::with_state(state)
}
Expand Down
39 changes: 19 additions & 20 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,29 @@ use route_recognizer::Params;

use std::ops::Index;
use std::pin::Pin;
use std::{fmt, str::FromStr, sync::Arc};
use std::{fmt, str::FromStr};

use crate::cookies::CookieData;
use crate::http::cookies::Cookie;
use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues};
use crate::http::{self, Body, Method, Mime, StatusCode, Url, Version};
use crate::Response;

/// An HTTP request.
///
/// The `Request` gives endpoints access to basic information about the incoming
/// request, route parameters, and various ways of accessing the request's body.
///
/// Requests also provide *extensions*, a type map primarily used for low-level
/// communication between middleware and endpoints.
#[derive(Debug)]
pub struct Request<State> {
pub(crate) state: Arc<State>,
pub(crate) req: http::Request,
pub(crate) route_params: Vec<Params>,
pin_project_lite::pin_project! {
/// An HTTP request.
///
/// The `Request` gives endpoints access to basic information about the incoming
/// request, route parameters, and various ways of accessing the request's body.
///
/// Requests also provide *extensions*, a type map primarily used for low-level
/// communication between middleware and endpoints.
#[derive(Debug)]
pub struct Request<State> {
pub(crate) state: State,
#[pin]
pub(crate) req: http::Request,
pub(crate) route_params: Vec<Params>,
}
}

#[derive(Debug)]
Expand All @@ -45,11 +48,7 @@ impl<T: fmt::Debug + fmt::Display> std::error::Error for ParamError<T> {}

impl<State> Request<State> {
/// Create a new `Request`.
pub(crate) fn new(
state: Arc<State>,
req: http_types::Request,
route_params: Vec<Params>,
) -> Self {
pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec<Params>) -> Self {
Self {
state,
req,
Expand Down Expand Up @@ -550,11 +549,11 @@ impl<State> AsMut<http::Headers> for Request<State> {

impl<State> Read for Request<State> {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.req).poll_read(cx, buf)
self.project().req.poll_read(cx, buf)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ impl<'a, State: Send + Sync + 'static> Route<'a, State> {
/// [`Server`]: struct.Server.html
pub fn nest<InnerState>(&mut self, service: crate::Server<InnerState>) -> &mut Self
where
State: Send + Sync + 'static,
InnerState: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
InnerState: Clone + Send + Sync + 'static,
{
self.prefix = true;
self.all(service);
Expand Down
17 changes: 9 additions & 8 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ use crate::{Endpoint, Request, Route};
#[allow(missing_debug_implementations)]
pub struct Server<State> {
router: Arc<Router<State>>,
state: Arc<State>,
state: State,
middleware: Arc<Vec<Arc<dyn Middleware<State>>>>,
}

Expand Down Expand Up @@ -166,7 +166,7 @@ impl Default for Server<()> {
}
}

impl<State: Send + Sync + 'static> Server<State> {
impl<State: Clone + Send + Sync + 'static> Server<State> {
/// Create a new Tide server with shared application scoped state.
///
/// Application scoped state is useful for storing items
Expand All @@ -177,6 +177,7 @@ impl<State: Send + Sync + 'static> Server<State> {
/// # use async_std::task::block_on;
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
/// #
/// use std::sync::Arc;
/// use tide::Request;
///
/// /// The shared application state.
Expand All @@ -190,8 +191,8 @@ impl<State: Send + Sync + 'static> Server<State> {
/// };
///
/// // Initialize the application with state.
/// let mut app = tide::with_state(state);
/// app.at("/").get(|req: Request<State>| async move {
/// let mut app = tide::with_state(Arc::new(state));
/// app.at("/").get(|req: Request<Arc<State>>| async move {
/// Ok(format!("Hello, {}!", &req.state().name))
/// });
/// app.listen("127.0.0.1:8080").await?;
Expand All @@ -202,7 +203,7 @@ impl<State: Send + Sync + 'static> Server<State> {
let mut server = Self {
router: Arc::new(Router::new()),
middleware: Arc::new(vec![]),
state: Arc::new(state),
state,
};
server.middleware(cookies::CookiesMiddleware::new());
server.middleware(log::LogMiddleware::new());
Expand Down Expand Up @@ -429,7 +430,7 @@ impl<State: Send + Sync + 'static> Server<State> {
}
}

impl<State> Clone for Server<State> {
impl<State: Clone> Clone for Server<State> {
fn clone(&self) -> Self {
Self {
router: self.router.clone(),
Expand All @@ -439,8 +440,8 @@ impl<State> Clone for Server<State> {
}
}

impl<State: Sync + Send + 'static, InnerState: Sync + Send + 'static> Endpoint<State>
for Server<InnerState>
impl<State: Clone + Sync + Send + 'static, InnerState: Clone + Sync + Send + 'static>
Endpoint<State> for Server<InnerState>
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, crate::Result> {
let Request {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait ServerTestingExt {

impl<State> ServerTestingExt for Server<State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
fn request<'a>(
&'a self,
Expand Down

0 comments on commit da703eb

Please sign in to comment.