Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CORS & testing::TestResponse::header #153

Merged
merged 3 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 105 additions & 9 deletions ohkami/src/builtin/fang/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub(crate) enum AccessControlAllowOrigin {
}

impl CORS {
/// Create `CORS` fang for specified `AllowOrigin` as `Access-Control-Allow-Origin` header.\
/// Create `CORS` fang using given `AllowOrigin` as `Access-Control-Allow-Origin` header value.\
/// (Both `"*"` and a speciffic origin are available)
#[allow(non_snake_case)]
pub const fn new(AllowOrigin: &'static str) -> Self {
Expand All @@ -74,20 +74,24 @@ impl CORS {
Self {
AllowOrigin: AccessControlAllowOrigin::from_literal(AllowOrigin),
AllowCredentials: false,
AllowMethods: Some(&[GET, HEAD, PUT, POST, DELETE, PATCH]),
AllowMethods: Some(&[GET, PUT, POST, PATCH, DELETE, OPTIONS, HEAD]),
AllowHeaders: None,
ExposeHeaders: None,
MaxAge: None,
}
}

pub const fn AllowCredentials(mut self) -> Self {
pub fn AllowCredentials(mut self) -> Self {
if self.AllowOrigin.is_any() {
panic!("\
The value of the 'Access-Control-Allow-Origin' header in the response \
must not be the wildcard '*' when the request's credentials mode is 'include'.\
")
#[cfg(feature="DEBUG")] eprintln!("\
[WRANING] \
'Access-Control-Allow-Origin' header \
must not have wildcard '*' when the request's credentials mode is 'include' \
");

return self
}

self.AllowCredentials = true;
self
}
Expand Down Expand Up @@ -150,8 +154,8 @@ impl<Inner: FangProc> FangProc for CORSProc<Inner> {
}
if let Some(allow_methods) = self.cors.AllowMethods {
let methods_string = allow_methods.iter()
.map(Method::as_str)
.fold(String::new(), |mut ms, m| {ms.push_str(m); ms});
.map(Method::as_str).collect::<Vec<_>>()
.join(",");
h = h.AccessControlAllowMethods(methods_string);
}
if let Some(allow_headers_string) = match self.cors.AllowHeaders {
Expand All @@ -166,6 +170,98 @@ impl<Inner: FangProc> FangProc for CORSProc<Inner> {
res.status = Status::NoContent;
}

#[cfg(feature="DEBUG")]
println!("After CORS proc: res = {res:#?}");

res
}
}




#[cfg(any(feature="rt_tokio",feature="rt_async-std"))]
#[cfg(feature="testing")]
#[cfg(test)]
mod test {
use crate::prelude::*;
use crate::testing::*;
use super::CORS;

#[crate::__rt__::test] async fn cors_headers() {
let t = Ohkami::with(
CORS::new("https://example.example"),
"/".GET(|| async {"Hello!"})
).test(); {
let req = TestRequest::GET("/");
let res = t.oneshot(req).await;

assert_eq!(res.status().code(), 200);
assert_eq!(res.text(), Some("Hello!"));

assert_eq!(res.header("Access-Control-Allow-Origin"), Some("https://example.example"));
assert_eq!(res.header("Access-Control-Allow-Credentials"), None);
assert_eq!(res.header("Access-Control-Expose-Headers"), None);
assert_eq!(res.header("Access-Control-Max-Age"), None);
assert_eq!(res.header("Access-Control-Allow-Methods"), None);
assert_eq!(res.header("Access-Control-Allow-Headers"), None);
assert_eq!(res.header("Vary"), None);
}

let t = Ohkami::with(
CORS::new("https://example.example")
.AllowCredentials()
.AllowHeaders(&["Content-Type", "X-Custom"]),
"/".GET(|| async {"Hello!"})
).test(); {
let req = TestRequest::GET("/");
let res = t.oneshot(req).await;

assert_eq!(res.status().code(), 200);
assert_eq!(res.text(), Some("Hello!"));

assert_eq!(res.header("Access-Control-Allow-Origin"), Some("https://example.example"));
assert_eq!(res.header("Access-Control-Allow-Credentials"), Some("true"));
assert_eq!(res.header("Access-Control-Expose-Headers"), None);
assert_eq!(res.header("Access-Control-Max-Age"), None);
assert_eq!(res.header("Access-Control-Allow-Methods"), None);
assert_eq!(res.header("Access-Control-Allow-Headers"), None);
assert_eq!(res.header("Vary"), None);
} {
let req = TestRequest::OPTIONS("/");
let res = t.oneshot(req).await;

assert_eq!(res.status().code(), 204);
assert_eq!(res.text(), None);

assert_eq!(res.header("Access-Control-Allow-Origin"), Some("https://example.example"));
assert_eq!(res.header("Access-Control-Allow-Credentials"), Some("true"));
assert_eq!(res.header("Access-Control-Expose-Headers"), None);
assert_eq!(res.header("Access-Control-Max-Age"), None);
assert_eq!(res.header("Access-Control-Allow-Methods"), Some("GET,PUT,POST,PATCH,DELETE,OPTIONS,HEAD"));
assert_eq!(res.header("Access-Control-Allow-Headers"), Some("Content-Type,X-Custom"));
assert_eq!(res.header("Vary"), Some("Access-Control-Request-Headers"));
}

let t = Ohkami::with(
CORS::new("*")
.AllowHeaders(&["Content-Type", "X-Custom"])
.MaxAge(1024),
"/".GET(|| async {"Hello!"})
).test(); {
let req = TestRequest::OPTIONS("/");
let res = t.oneshot(req).await;

assert_eq!(res.status().code(), 204);
assert_eq!(res.text(), None);

assert_eq!(res.header("Access-Control-Allow-Origin"), Some("*"));
assert_eq!(res.header("Access-Control-Allow-Credentials"), None);
assert_eq!(res.header("Access-Control-Expose-Headers"), None);
assert_eq!(res.header("Access-Control-Max-Age"), Some("1024"));
assert_eq!(res.header("Access-Control-Allow-Methods"), Some("GET,PUT,POST,PATCH,DELETE,OPTIONS,HEAD"));
assert_eq!(res.header("Access-Control-Allow-Headers"), Some("Content-Type,X-Custom"));
assert_eq!(res.header("Vary"), Some("Origin,Access-Control-Request-Headers"));
}
}
}
16 changes: 6 additions & 10 deletions ohkami/src/response/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,23 +195,19 @@ macro_rules! Header {
pub const fn as_str(&self) -> &'static str {
unsafe {std::str::from_utf8_unchecked(self.as_bytes())}
}

pub const fn from_bytes(bytes: &[u8]) -> Option<Self> {
match bytes {
$(
$name_bytes => Some(Self::$konst),
)*
_ => None
}
}

#[inline(always)] const fn len(&self) -> usize {
match self {
$(
Self::$konst => $len,
)*
}
}

// Mainly used in tests
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
SERVER_HEADERS.into_iter()
.find(|h| h.as_bytes().eq_ignore_ascii_case(bytes))
}
}

impl<T: AsRef<[u8]>> PartialEq<T> for Header {
Expand Down
18 changes: 3 additions & 15 deletions ohkami/src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,9 @@ impl TestResponse {
}

pub fn header(&self, name: &'static str) -> Option<&str> {
let name_bytes = name.split('-').map(|section| {
if section.eq_ignore_ascii_case("ETag") {
f!("ETag")
} else if section.eq_ignore_ascii_case("WebSocket") {
f!("WebSocket")
} else {
let mut section_chars = section.chars();
let first = section_chars.next().expect("Found `--` in header name").to_ascii_uppercase();
section_chars.fold(
String::from(first),
|mut section, ch| {section.push(ch); section}
)
}
}).collect::<String>();
self.0.headers.get(ResponseHeader::from_bytes(name_bytes.as_bytes())?)
ResponseHeader::from_bytes(name.as_bytes())
.and_then(|h| self.0.headers.get(h))
.or_else(|| self.0.headers.get_custom(name))
}

pub fn text(&self) -> Option<&str> {
Expand Down
Loading