Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try using routefinder #802

Merged
merged 4 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ http-types = { version = "2.11.0", default-features = false, features = ["fs"] }
kv-log-macro = "1.0.7"
log = { version = "0.4.13", features = ["kv_unstable_std"] }
pin-project-lite = "0.2.0"
route-recognizer = "0.2.0"
serde = "1.0.117"
serde_json = "1.0.59"
routefinder = "0.4.0"

[dev-dependencies]
async-std = { version = "1.6.5", features = ["unstable", "attributes"] }
Expand Down
53 changes: 40 additions & 13 deletions src/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_std::io::{self, prelude::*};
use async_std::task::{Context, Poll};
use route_recognizer::Params;
use routefinder::Captures;

use std::ops::Index;
use std::pin::Pin;
Expand All @@ -27,13 +27,17 @@ pin_project_lite::pin_project! {
pub(crate) state: State,
#[pin]
pub(crate) req: http::Request,
pub(crate) route_params: Vec<Params>,
pub(crate) route_params: Vec<Captures<'static, 'static>>,
}
}

impl<State> Request<State> {
/// Create a new `Request`.
pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec<Params>) -> Self {
pub(crate) fn new(
state: State,
req: http_types::Request,
route_params: Vec<Captures<'static, 'static>>,
) -> Self {
Self {
state,
req,
Expand Down Expand Up @@ -266,8 +270,7 @@ impl<State> Request<State> {
///
/// Returns the parameter as a `&str`, borrowed from this `Request`.
///
/// The name should *not* include the leading `:` or the trailing `*` (if
/// any).
/// The name should *not* include the leading `:`.
///
/// # Errors
///
Expand Down Expand Up @@ -297,10 +300,40 @@ impl<State> Request<State> {
self.route_params
.iter()
.rev()
.find_map(|params| params.find(key))
.find_map(|captures| captures.get(key))
.ok_or_else(|| format_err!("Param \"{}\" not found", key.to_string()))
}

/// Fetch the wildcard from the route, if it exists
///
/// Returns the parameter as a `&str`, borrowed from this `Request`.
///
/// # Examples
///
/// ```no_run
/// # use async_std::task::block_on;
/// # fn main() -> Result<(), std::io::Error> { block_on(async {
/// #
/// use tide::{Request, Result};
///
/// async fn greet(req: Request<()>) -> Result<String> {
/// let name = req.wildcard().unwrap_or("world");
/// Ok(format!("Hello, {}!", name))
/// }
///
/// let mut app = tide::new();
/// app.at("/hello/*").get(greet);
/// app.listen("127.0.0.1:8080").await?;
/// #
/// # Ok(()) })}
/// ```
pub fn wildcard(&self) -> Option<&str> {
self.route_params
.iter()
.rev()
.find_map(|captures| captures.wildcard())
}

/// Parse the URL query component into a struct, using [serde_qs](https://docs.rs/serde_qs). To
/// get the entire query as an unparsed string, use `request.url().query()`.
///
Expand Down Expand Up @@ -565,7 +598,7 @@ impl<State> From<Request<State>> for http::Request {

impl<State: Default> From<http_types::Request> for Request<State> {
fn from(request: http_types::Request) -> Request<State> {
Request::new(State::default(), request, Vec::<Params>::new())
Request::new(State::default(), request, vec![])
}
}

Expand Down Expand Up @@ -635,9 +668,3 @@ impl<State> Index<&str> for Request<State> {
&self.req[name]
}
}

pub(crate) fn rest(route_params: &[Params]) -> Option<&str> {
route_params
.last()
.and_then(|params| params.find("--tide-path-rest"))
}
22 changes: 8 additions & 14 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> {
pub fn method(&mut self, method: http_types::Method, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);

self.router.add(
&self.path,
method,
MiddlewareEndpoint::wrap_with_middleware(ep.clone(), &self.middleware),
);
let wildcard = self.at("*--tide-path-rest");
let wildcard = self.at("*");
wildcard.router.add(
&wildcard.path,
method,
Expand All @@ -181,12 +175,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> {
pub fn all(&mut self, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);

self.router.add_all(
&self.path,
MiddlewareEndpoint::wrap_with_middleware(ep.clone(), &self.middleware),
);
let wildcard = self.at("*--tide-path-rest");
let wildcard = self.at("*");
wildcard.router.add_all(
&wildcard.path,
MiddlewareEndpoint::wrap_with_middleware(ep, &wildcard.middleware),
Expand Down Expand Up @@ -283,7 +272,12 @@ where
route_params,
} = req;

let rest = crate::request::rest(&route_params).unwrap_or("");
let rest = route_params
.iter()
.rev()
.find_map(|captures| captures.wildcard())
.unwrap_or_default();

req.url_mut().set_path(rest);

self.0
Expand Down
37 changes: 23 additions & 14 deletions src/router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use route_recognizer::{Match, Params, Router as MethodRouter};
use routefinder::{Captures, Router as MethodRouter};
use std::collections::HashMap;

use crate::endpoint::DynEndpoint;
Expand All @@ -14,11 +14,19 @@ pub(crate) struct Router<State> {
all_method_router: MethodRouter<Box<DynEndpoint<State>>>,
}

impl<State> std::fmt::Debug for Router<State> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Router")
.field("method_map", &self.method_map)
.field("all_method_router", &self.all_method_router)
.finish()
}
}

/// The result of routing a URL
#[allow(missing_debug_implementations)]
pub(crate) struct Selection<'a, State> {
pub(crate) endpoint: &'a DynEndpoint<State>,
pub(crate) params: Params,
pub(crate) params: Captures<'static, 'static>,
}

impl<State: Clone + Send + Sync + 'static> Router<State> {
Expand All @@ -39,26 +47,27 @@ impl<State: Clone + Send + Sync + 'static> Router<State> {
.entry(method)
.or_insert_with(MethodRouter::new)
.add(path, ep)
.unwrap()
}

pub(crate) fn add_all(&mut self, path: &str, ep: Box<DynEndpoint<State>>) {
self.all_method_router.add(path, ep)
self.all_method_router.add(path, ep).unwrap()
}

pub(crate) fn route(&self, path: &str, method: http_types::Method) -> Selection<'_, State> {
if let Some(Match { handler, params }) = self
if let Some(m) = self
.method_map
.get(&method)
.and_then(|r| r.recognize(path).ok())
.and_then(|r| r.best_match(path))
{
Selection {
endpoint: &**handler,
params,
endpoint: m.handler(),
params: m.captures().into_owned(),
}
} else if let Ok(Match { handler, params }) = self.all_method_router.recognize(path) {
} else if let Some(m) = self.all_method_router.best_match(path) {
Selection {
endpoint: &**handler,
params,
endpoint: m.handler(),
params: m.captures().into_owned(),
}
} else if method == http_types::Method::Head {
// If it is a HTTP HEAD request then check if there is a callback in the endpoints map
Expand All @@ -69,18 +78,18 @@ impl<State: Clone + Send + Sync + 'static> Router<State> {
.method_map
.iter()
.filter(|(k, _)| **k != method)
.any(|(_, r)| r.recognize(path).is_ok())
.any(|(_, r)| r.best_match(path).is_some())
{
// If this `path` can be handled by a callback registered with a different HTTP method
// should return 405 Method Not Allowed
Selection {
endpoint: &method_not_allowed,
params: Params::new(),
params: Captures::default(),
}
} else {
Selection {
endpoint: &not_found_endpoint,
params: Params::new(),
params: Captures::default(),
}
}
}
Expand Down
75 changes: 18 additions & 57 deletions tests/wildcard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@ async fn add_two(req: Request<()>) -> Result<String, tide::Error> {
Ok((one + two).to_string())
}

async fn echo_path(req: Request<()>) -> Result<String, tide::Error> {
match req.param("path") {
async fn echo_param(req: Request<()>) -> tide::Result<tide::Response> {
match req.param("param") {
Ok(path) => Ok(path.into()),
Err(mut err) => {
err.set_status(StatusCode::BadRequest);
Err(err)
}
Err(_) => Ok(StatusCode::NotFound.into()),
}
}

async fn echo_wildcard(req: Request<()>) -> tide::Result<tide::Response> {
match req.wildcard() {
Some(path) => Ok(path.into()),
None => Ok(StatusCode::NotFound.into()),
}
}

#[async_std::test]
async fn wildcard() -> tide::Result<()> {
async fn param() -> tide::Result<()> {
let mut app = tide::Server::new();
app.at("/add_one/:num").get(add_one);
assert_eq!(app.get("/add_one/3").recv_string().await?, "4");
Expand Down Expand Up @@ -61,20 +65,21 @@ async fn not_found_error() -> tide::Result<()> {
}

#[async_std::test]
async fn wild_path() -> tide::Result<()> {
async fn wildcard() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/*path").get(echo_path);
app.at("/echo/*").get(echo_wildcard);
assert_eq!(app.get("/echo/some_path").recv_string().await?, "some_path");
assert_eq!(
app.get("/echo/multi/segment/path").recv_string().await?,
"multi/segment/path"
);
assert_eq!(app.get("/echo/").await?.status(), StatusCode::NotFound);
assert_eq!(app.get("/echo/").await?.status(), StatusCode::Ok);
assert_eq!(app.get("/echo").await?.status(), StatusCode::Ok);
Ok(())
}

#[async_std::test]
async fn multi_wildcard() -> tide::Result<()> {
async fn multi_param() -> tide::Result<()> {
let mut app = tide::new();
app.at("/add_two/:one/:two/").get(add_two);
assert_eq!(app.get("/add_two/1/2/").recv_string().await?, "3");
Expand All @@ -84,9 +89,9 @@ async fn multi_wildcard() -> tide::Result<()> {
}

#[async_std::test]
async fn wild_last_segment() -> tide::Result<()> {
async fn wildcard_last_segment() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/:path/*").get(echo_path);
app.at("/echo/:param/*").get(echo_param);
assert_eq!(app.get("/echo/one/two").recv_string().await?, "one");
assert_eq!(
app.get("/echo/one/two/three/four").recv_string().await?,
Expand All @@ -95,50 +100,6 @@ async fn wild_last_segment() -> tide::Result<()> {
Ok(())
}

#[async_std::test]
async fn invalid_wildcard() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/*path/:one/").get(echo_path);
assert_eq!(
app.get("/echo/one/two").await?.status(),
StatusCode::NotFound
);
Ok(())
}

#[async_std::test]
async fn nameless_wildcard() -> tide::Result<()> {
let mut app = tide::Server::new();
app.at("/echo/:").get(|_| async { Ok("") });
assert_eq!(
app.get("/echo/one/two").await?.status(),
StatusCode::NotFound
);
assert_eq!(app.get("/echo/one").await?.status(), StatusCode::Ok);
Ok(())
}

#[async_std::test]
async fn nameless_internal_wildcard() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/:/:path").get(echo_path);
assert_eq!(app.get("/echo/one").await?.status(), StatusCode::NotFound);
assert_eq!(app.get("/echo/one/two").recv_string().await?, "two");
Ok(())
}

#[async_std::test]
async fn nameless_internal_wildcard2() -> tide::Result<()> {
let mut app = tide::new();
app.at("/echo/:/:path").get(|req: Request<()>| async move {
assert_eq!(req.param("path")?, "two");
Ok("")
});

assert!(app.get("/echo/one/two").await?.status().is_success());
Ok(())
}

#[async_std::test]
async fn ambiguous_router_wildcard_vs_star() -> tide::Result<()> {
let mut app = tide::new();
Expand Down