Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axum updates #2738

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
383 changes: 272 additions & 111 deletions Cargo.lock

Large diffs are not rendered by default.

31 changes: 23 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ serde_with = "3.7.0"
argon2 = { version = "0.5.3", features = ["alloc"] }
async-recursion = "1.1.0"
async-trait = "^0.1.78"
axum = { version = "0.6.20", features = [
axum = { version = "0.7.5", features = [
"form",
"headers",
# "headers",
"http2",
"json",
"macros",
Expand All @@ -100,7 +100,11 @@ axum = { version = "0.6.20", features = [
"query",
"tracing",
] }
axum-csp = { version = "0.0.5" }
axum-auth = "0.4.1"
axum-csp = { path = "../axum-csp" }
axum-extra = "0.9.3"
axum-macros = "0.3.8"
axum-server = "0.5.1"
base32 = "^0.4.0"
base64 = "^0.21.7"
base64urlsafedata = "0.5.0"
Expand Down Expand Up @@ -129,9 +133,10 @@ gloo = "^0.8.1"
gloo-utils = "0.2.0"
hashbrown = { version = "0.14.3", features = ["serde", "inline-more", "ahash"] }
hex = "^0.4.3"
http = "0.2.12"
hyper = { version = "0.14.28", features = ["full"] }
hyper-tls = "0.5.0"
http = "1.1.0"
hyper = { version = "1.2.0", features = ["full"] }
hyper-util = { version = "0.1.3", features = ["server"] }
hyper-tls = "0.6.0"
idlset = "^0.2.4"
image = { version = "0.24.9", default-features = false, features = [
"gif",
Expand Down Expand Up @@ -210,10 +215,12 @@ time = { version = "^0.3.34", features = ["formatting", "local-offset"] }
tikv-jemallocator = "0.5"

tokio = "^1.36.0"
tokio-openssl = "^0.6.4"
tokio-openssl = "0.6.4"
tokio-util = "^0.7.10"

toml = "^0.5.11"
tower = "0.4.13"
tower-http = "0.5.2"
tracing = { version = "^0.1.40", features = [
"max_level_trace",
"release_max_level_debug",
Expand All @@ -225,7 +232,7 @@ tracing-forest = "^0.1.6"
url = "^2.5.0"
urlencoding = "2.1.3"
utoipa = "4.2.0"
utoipa-swagger-ui = "4.0.0"
utoipa-swagger-ui = "6.0.0"
uuid = "^1.8.0"

wasm-bindgen = "^0.2.92"
Expand All @@ -250,3 +257,11 @@ yew-router = "^0.17.0"
zxcvbn = "^2.2.2"

nonempty = "0.8.1"
assert_cmd = "2.0.14"
escargot = "0.5.10"
fantoccini = "0.19.3"
gethostname = "0.4.3"
gloo-timers = "0.3.0"
jsonschema = "0.17.1"
petgraph = "0.6.4"
wasm-timer = "0.2.5"
2 changes: 1 addition & 1 deletion libs/sketching/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test = false
doctest = false

[dependencies]
gethostname = "0.4.3"
gethostname = { workspace = true }
num_enum = { workspace = true }
opentelemetry = { workspace = true, features = ["metrics", "rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = [
Expand Down
14 changes: 8 additions & 6 deletions server/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ doctest = false

[dependencies]
async-trait = { workspace = true }
axum = { workspace = true }
axum-auth = "0.4.1"
axum = { workspace = true, features = ["form"] }
axum-auth = { workspace = true }
axum-csp = { workspace = true }
axum-macros = "0.3.8"
axum-server = { version = "0.5.1", features = ["tls-openssl"] }
axum-extra = { workspace = true, features = ["typed-header"] }
axum-macros = { workspace = true }
axum-server = { workspace = true, features = ["tls-openssl"] }
bytes = { workspace = true }
chrono = { workspace = true }
compact_jwt = { workspace = true }
Expand All @@ -32,6 +33,7 @@ futures-util = { workspace = true }
hashbrown = { workspace = true }
http = { workspace = true }
hyper = { workspace = true }
hyper-util = { workspace = true }
kanidm_proto = { workspace = true }
kanidm_utils_users = { workspace = true }
kanidmd_lib = { workspace = true }
Expand All @@ -51,8 +53,8 @@ tokio = { workspace = true, features = ["net", "sync", "io-util", "macros"] }
tokio-openssl = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }
toml = { workspace = true }
tower = { version = "0.4.13", features = ["tokio-stream", "tracing"] }
tower-http = { version = "0.4.4", features = [
tower = { workspace = true, features = ["tokio-stream", "tracing"] }
tower-http = { workspace = true, features = [
"compression-gzip",
"fs",
"tokio",
Expand Down
48 changes: 28 additions & 20 deletions server/core/src/https/extractors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
use axum::{
async_trait,
extract::connect_info::{ConnectInfo, Connected},
extract::FromRequestParts,
extract::{
connect_info::{ConnectInfo, Connected},
FromRequestParts,
},
http::{
header::HeaderName, header::AUTHORIZATION as AUTHORISATION, request::Parts, StatusCode,
header::{HeaderName, AUTHORIZATION as AUTHORISATION},
request::Parts,
StatusCode,
},
RequestPartsExt,
};
use hyper::server::conn::AddrStream;

use kanidm_proto::constants::X_FORWARDED_FOR;
use kanidmd_lib::prelude::{ClientAuthInfo, ClientCertInfo, Source};
use tokio::net::TcpStream;

use compact_jwt::JwsCompact;
use std::str::FromStr;
Expand Down Expand Up @@ -82,7 +87,7 @@ pub struct VerifiedClientInformation(pub ClientAuthInfo);

#[async_trait]
impl FromRequestParts<ServerState> for VerifiedClientInformation {
type Rejection = (StatusCode, &'static str);
type Rejection = &'static str;

#[instrument(level = "debug", skip(state))]
async fn from_request_parts(
Expand All @@ -94,10 +99,11 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation {
.await
.map_err(|_| {
error!("Connect info contains invalid data");
(
StatusCode::BAD_REQUEST,
"connect info contains invalid data",
)
// (
// StatusCode::BAD_REQUEST,
"connect info contains invalid data"
// ,
// )
})?;

let ip_addr = if state.trust_x_forward_for {
Expand All @@ -109,17 +115,19 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation {
// Split on an optional comma, return the first result.
s.split(',').next().unwrap_or(s))
.map_err(|_| {
(
StatusCode::BAD_REQUEST,
"X-Forwarded-For contains invalid data",
)
// (
// StatusCode::BAD_REQUEST,
"X-Forwarded-For contains invalid data"
// ,
// )
})?;

first.parse::<IpAddr>().map_err(|_| {
(
StatusCode::BAD_REQUEST,
"X-Forwarded-For contains invalid ip addr",
)
// (
// StatusCode::BAD_REQUEST,
"X-Forwarded-For contains invalid ip addr"
// ,
// )
})?
} else {
addr.ip()
Expand Down Expand Up @@ -181,10 +189,10 @@ impl Connected<ClientConnInfo> for ClientConnInfo {
}
}

impl<'a> Connected<&'a AddrStream> for ClientConnInfo {
fn connect_info(target: &'a AddrStream) -> Self {
impl<'a> Connected<&'a TcpStream> for ClientConnInfo {
fn connect_info(target: &'a TcpStream) -> Self {
ClientConnInfo {
addr: target.remote_addr(),
addr: target.peer_addr().expect("Failed to get peer addr"),
client_cert: None,
}
}
Expand Down
7 changes: 4 additions & 3 deletions server/core/src/https/middleware/caching.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use axum::{
headers::{CacheControl, HeaderMapExt},
body::Body,
http::{header, HeaderValue, Request},
middleware::Next,
response::Response,
};
use axum_extra::headers::{CacheControl, HeaderMapExt};

/// Adds `no-cache max-age=0` to the response headers.
pub async fn dont_cache_me<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn dont_cache_me(request: Request<Body>, next: Next) -> Response {
let mut response = next.run(request).await;
response.headers_mut().insert(
header::CACHE_CONTROL,
Expand All @@ -20,7 +21,7 @@ pub async fn dont_cache_me<B>(request: Request<B>, next: Next<B>) -> Response {
}

/// Adds a cache control header of 300 seconds to the response headers.
pub async fn cache_me<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn cache_me(request: Request<Body>, next: Next) -> Response {
let mut response = next.run(request).await;
let cache_header = CacheControl::new()
.with_max_age(std::time::Duration::from_secs(300))
Expand Down
3 changes: 2 additions & 1 deletion server/core/src/https/middleware/hsts_header.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use axum::body::Body;
use axum::http::{header, HeaderValue, Request};
use axum::middleware::Next;
use axum::response::Response;

const HSTS_HEADER: &str = "max-age=86400";

pub async fn strict_transport_security_layer<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn strict_transport_security_layer(request: Request<Body>, next: Next) -> Response {
// wait for the middleware to come back
let mut response = next.run(request).await;

Expand Down
7 changes: 4 additions & 3 deletions server/core/src/https/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use axum::{
body::Body,
http::{HeaderValue, Request},
middleware::Next,
response::Response,
Expand All @@ -15,7 +16,7 @@ pub(crate) mod security_headers;
const KANIDM_VERSION: &str = env!("CARGO_PKG_VERSION");

/// Injects a header into the response with "X-KANIDM-VERSION" matching the version of the package.
pub async fn version_middleware<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn version_middleware(request: Request<Body>, next: Next) -> Response {
let mut response = next.run(request).await;
response
.headers_mut()
Expand All @@ -26,7 +27,7 @@ pub async fn version_middleware<B>(request: Request<B>, next: Next<B>) -> Respon
#[cfg(any(test, debug_assertions))]
/// This is a debug middleware to ensure that /v1/ endpoints only return JSON
#[instrument(level = "trace", name = "are_we_json_yet", skip_all)]
pub async fn are_we_json_yet<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn are_we_json_yet(request: Request<Body>, next: Next) -> Response {
let uri = request.uri().path().to_string();

let response = next.run(request).await;
Expand Down Expand Up @@ -54,7 +55,7 @@ pub struct KOpId {

/// This runs at the start of the request, adding an extension with `KOpId` which has useful things inside it.
#[instrument(level = "trace", name = "kopid_middleware", skip_all)]
pub async fn kopid_middleware<B>(mut request: Request<B>, next: Next<B>) -> Response {
pub async fn kopid_middleware(mut request: Request<Body>, next: Next) -> Response {
// generate the event ID
let eventid = sketching::tracing_forest::id();

Expand Down
7 changes: 4 additions & 3 deletions server/core/src/https/middleware/security_headers.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use axum::body::Body;
use axum::extract::State;
use axum::http::header;
use axum::http::HeaderValue;
Expand All @@ -10,10 +11,10 @@ use crate::https::ServerState;
const PERMISSIONS_POLICY_VALUE: &str = "fullscreen=(), geolocation=()";
const X_CONTENT_TYPE_OPTIONS_VALUE: &str = "nosniff";

pub async fn security_headers_layer<B>(
pub async fn security_headers_layer(
State(state): State<ServerState>,
request: Request<B>,
next: Next<B>,
request: Request<Body>,
next: Next,
) -> Response {
// wait for the middleware to come back
let mut response = next.run(request).await;
Expand Down