Skip to content

Commit

Permalink
try using routefinder
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Apr 22, 2021
1 parent d81eabc commit 81d566b
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 103 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ name = "tide"
version = "0.16.0"
description = "A minimal and pragmatic Rust web application framework built for rapid development"
authors = [
"Aaron Turon <aturon@mozilla.com>",
"Yoshua Wuyts <yoshuawuyts@gmail.com>",
"Wonwoo Choi <chwo9843@gmail.com>",
"Aaron Turon <aturon@mozilla.com>",
"Yoshua Wuyts <yoshuawuyts@gmail.com>",
"Wonwoo Choi <chwo9843@gmail.com>",
]
documentation = "https://docs.rs/tide"
keywords = ["tide", "http", "web", "framework", "async"]
Expand Down Expand Up @@ -34,7 +34,7 @@ unstable = []

[dependencies]
async-h1 = { version = "2.3.0", optional = true }
async-session = { version = "2.0.1", optional = true }
async-session = { version = "2.0.1", optional = true }
async-sse = "4.0.1"
async-std = { version = "1.6.5", features = ["unstable"] }
async-trait = "0.1.41"
Expand All @@ -45,9 +45,9 @@ http-types = { version = "2.11.0", default-features = false, features = ["fs"] }
kv-log-macro = "1.0.7"
log = { version = "0.4.13", features = ["kv_unstable_std"] }
pin-project-lite = "0.2.0"
route-recognizer = "0.2.0"
serde = "1.0.117"
serde_json = "1.0.59"
routefinder = "0.1.1"

[dev-dependencies]
async-std = { version = "1.6.5", features = ["unstable", "attributes"] }
Expand Down
49 changes: 36 additions & 13 deletions src/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_std::io::{self, prelude::*};
use async_std::task::{Context, Poll};
use route_recognizer::Params;
use routefinder::Captures;

use std::ops::Index;
use std::pin::Pin;
Expand All @@ -27,13 +27,13 @@ pin_project_lite::pin_project! {
pub(crate) state: State,
#[pin]
pub(crate) req: http::Request,
pub(crate) route_params: Vec<Params>,
pub(crate) route_params: Vec<Captures>,
}
}

impl<State> Request<State> {
/// Create a new `Request`.
pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec<Params>) -> Self {
pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec<Captures>) -> Self {
Self {
state,
req,
Expand Down Expand Up @@ -266,8 +266,7 @@ impl<State> Request<State> {
///
/// Returns the parameter as a `&str`, borrowed from this `Request`.
///
/// The name should *not* include the leading `:` or the trailing `*` (if
/// any).
/// The name should *not* include the leading `:`.
///
/// # Errors
///
Expand Down Expand Up @@ -297,10 +296,40 @@ impl<State> Request<State> {
self.route_params
.iter()
.rev()
.find_map(|params| params.find(key))
.find_map(|captures| captures.get(key))
.ok_or_else(|| format_err!("Param \"{}\" not found", key.to_string()))
}

/// Fetch the wildcard from the route, if it exists
///
/// Returns the parameter as a `&str`, borrowed from this `Request`.
///
/// # Examples
///
/// ```no_run
/// # use async_std::task::block_on;
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
/// #
/// use tide::{Request, Result};
///
/// async fn greet(req: Request<()>) -> Result<String> {
/// let name = req.wildcard().unwrap_or("world");
/// Ok(format!("Hello, {}!", name))
/// }
///
/// let mut app = tide::new();
/// app.at("/hello/*").get(greet);
/// app.listen("127.0.0.1:8080").await?;
/// #
/// # Ok(()) })}
/// ```
pub fn wildcard(&self) -> Option<&str> {
self.route_params
.iter()
.rev()
.find_map(|captures| captures.wildcard())
}

/// Parse the URL query component into a struct, using [serde_qs](https://docs.rs/serde_qs). To
/// get the entire query as an unparsed string, use `request.url().query()`.
///
Expand Down Expand Up @@ -565,7 +594,7 @@ impl<State> From<Request<State>> for http::Request {

impl<State: Default> From<http_types::Request> for Request<State> {
fn from(request: http_types::Request) -> Request<State> {
Request::new(State::default(), request, Vec::<Params>::new())
Request::new(State::default(), request, vec![])
}
}

Expand Down Expand Up @@ -635,9 +664,3 @@ impl<State> Index<&str> for Request<State> {
&self.req[name]
}
}

pub(crate) fn rest(route_params: &[Params]) -> Option<&str> {
route_params
.last()
.and_then(|params| params.find("--tide-path-rest"))
}
22 changes: 8 additions & 14 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> {
pub fn method(&mut self, method: http_types::Method, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);

self.router.add(
&self.path,
method,
MiddlewareEndpoint::wrap_with_middleware(ep.clone(), &self.middleware),
);
let wildcard = self.at("*--tide-path-rest");
let wildcard = self.at("*");
wildcard.router.add(
&wildcard.path,
method,
Expand All @@ -181,12 +175,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> {
pub fn all(&mut self, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);

self.router.add_all(
&self.path,
MiddlewareEndpoint::wrap_with_middleware(ep.clone(), &self.middleware),
);
let wildcard = self.at("*--tide-path-rest");
let wildcard = self.at("*");
wildcard.router.add_all(
&wildcard.path,
MiddlewareEndpoint::wrap_with_middleware(ep, &wildcard.middleware),
Expand Down Expand Up @@ -283,7 +272,12 @@ where
route_params,
} = req;

let rest = crate::request::rest(&route_params).unwrap_or("");
let rest = route_params
.iter()
.rev()
.find_map(|captures| captures.wildcard())
.unwrap_or_default();

req.url_mut().set_path(&rest);

self.0
Expand Down
37 changes: 23 additions & 14 deletions src/router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use route_recognizer::{Match, Params, Router as MethodRouter};
use routefinder::{Captures, Router as MethodRouter};
use std::collections::HashMap;

use crate::endpoint::DynEndpoint;
Expand All @@ -14,11 +14,19 @@ pub(crate) struct Router<State> {
all_method_router: MethodRouter<Box<DynEndpoint<State>>>,
}

impl<State> std::fmt::Debug for Router<State> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Router")
.field("method_map", &self.method_map)
.field("all_method_router", &self.all_method_router)
.finish()
}
}

/// The result of routing a URL
#[allow(missing_debug_implementations)]
pub(crate) struct Selection<'a, State> {
pub(crate) endpoint: &'a DynEndpoint<State>,
pub(crate) params: Params,
pub(crate) params: Captures,
}

impl<State: Clone + Send + Sync + 'static> Router<State> {
Expand All @@ -39,26 +47,27 @@ impl<State: Clone + Send + Sync + 'static> Router<State> {
.entry(method)
.or_insert_with(MethodRouter::new)
.add(path, ep)
.unwrap()
}

pub(crate) fn add_all(&mut self, path: &str, ep: Box<DynEndpoint<State>>) {
self.all_method_router.add(path, ep)
self.all_method_router.add(path, ep).unwrap()
}

pub(crate) fn route(&self, path: &str, method: http_types::Method) -> Selection<'_, State> {
if let Some(Match { handler, params }) = self
if let Some(m) = self
.method_map
.get(&method)
.and_then(|r| r.recognize(path).ok())
.and_then(|r| r.best_match(path))
{
Selection {
endpoint: &**handler,
params,
endpoint: m.handler(),
params: m.captures(),
}
} else if let Ok(Match { handler, params }) = self.all_method_router.recognize(path) {
} else if let Some(m) = self.all_method_router.best_match(path) {
Selection {
endpoint: &**handler,
params,
endpoint: m.handler(),
params: m.captures(),
}
} else if method == http_types::Method::Head {
// If it is a HTTP HEAD request then check if there is a callback in the endpoints map
Expand All @@ -69,18 +78,18 @@ impl<State: Clone + Send + Sync + 'static> Router<State> {
.method_map
.iter()
.filter(|(k, _)| **k != method)
.any(|(_, r)| r.recognize(path).is_ok())
.any(|(_, r)| r.best_match(path).is_some())
{
// If this `path` can be handled by a callback registered with a different HTTP method
// should return 405 Method Not Allowed
Selection {
endpoint: &method_not_allowed,
params: Params::new(),
params: Captures::default(),
}
} else {
Selection {
endpoint: &not_found_endpoint,
params: Params::new(),
params: Captures::default(),
}
}
}
Expand Down
75 changes: 18 additions & 57 deletions tests/wildcard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@ async fn add_two(req: Request<()>) -> Result<String, tide::Error> {
Ok((one + two).to_string())
}

async fn echo_path(req: Request<()>) -> Result<String, tide::Error> {
match req.param("path") {
async fn echo_param(req: Request<()>) -> tide::Result<tide::Response> {
match req.param("param") {
Ok(path) => Ok(path.into()),
Err(mut err) => {
err.set_status(StatusCode::BadRequest);
Err(err)
}
Err(_) => Ok(StatusCode::NotFound.into()),
}
}

async fn echo_wildcard(req: Request<()>) -> tide::Result<tide::Response> {
match req.wildcard() {
Some(path) => Ok(path.into()),
None => Ok(StatusCode::NotFound.into()),
}
}

#[async_std::test]
async fn wildcard() -> tide::Result<()> {
async fn param() -> tide::Result<()> {
let mut app = tide::Server::new();
app.at("/add_one/:num").get(add_one);
assert_eq!(app.get("/add_one/3").recv_string().await?, "4");
Expand Down Expand Up @@ -61,20 +65,21 @@ async fn not_found_error() -> tide::Result<()> {
}

#[async_std::test]
async fn wild_path() -> tide::Result<()> {
async fn wildcard() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/*path").get(echo_path);
app.at("/echo/*").get(echo_wildcard);
assert_eq!(app.get("/echo/some_path").recv_string().await?, "some_path");
assert_eq!(
app.get("/echo/multi/segment/path").recv_string().await?,
"multi/segment/path"
);
assert_eq!(app.get("/echo/").await?.status(), StatusCode::NotFound);
assert_eq!(app.get("/echo/").await?.status(), StatusCode::Ok);
assert_eq!(app.get("/echo").await?.status(), StatusCode::Ok);
Ok(())
}

#[async_std::test]
async fn multi_wildcard() -> tide::Result<()> {
async fn multi_param() -> tide::Result<()> {
let mut app = tide::new();
app.at("/add_two/:one/:two/").get(add_two);
assert_eq!(app.get("/add_two/1/2/").recv_string().await?, "3");
Expand All @@ -84,9 +89,9 @@ async fn multi_wildcard() -> tide::Result<()> {
}

#[async_std::test]
async fn wild_last_segment() -> tide::Result<()> {
async fn wildcard_last_segment() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/:path/*").get(echo_path);
app.at("/echo/:param/*").get(echo_param);
assert_eq!(app.get("/echo/one/two").recv_string().await?, "one");
assert_eq!(
app.get("/echo/one/two/three/four").recv_string().await?,
Expand All @@ -95,50 +100,6 @@ async fn wild_last_segment() -> tide::Result<()> {
Ok(())
}

#[async_std::test]
async fn invalid_wildcard() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/*path/:one/").get(echo_path);
assert_eq!(
app.get("/echo/one/two").await?.status(),
StatusCode::NotFound
);
Ok(())
}

#[async_std::test]
async fn nameless_wildcard() -> tide::Result<()> {
let mut app = tide::Server::new();
app.at("/echo/:").get(|_| async { Ok("") });
assert_eq!(
app.get("/echo/one/two").await?.status(),
StatusCode::NotFound
);
assert_eq!(app.get("/echo/one").await?.status(), StatusCode::Ok);
Ok(())
}

#[async_std::test]
async fn nameless_internal_wildcard() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/:/:path").get(echo_path);
assert_eq!(app.get("/echo/one").await?.status(), StatusCode::NotFound);
assert_eq!(app.get("/echo/one/two").recv_string().await?, "two");
Ok(())
}

#[async_std::test]
async fn nameless_internal_wildcard2() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/:/:path").get(|req: Request<()>| async move {
assert_eq!(req.param("path")?, "two");
Ok("")
});

assert!(app.get("/echo/one/two").await?.status().is_success());
Ok(())
}

#[async_std::test]
async fn ambiguous_router_wildcard_vs_star() -> tide::Result<()> {
let mut app = tide::new();
Expand Down

0 comments on commit 81d566b

Please sign in to comment.