diff --git a/Cargo.toml b/Cargo.toml index ed038d3..10b7916 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mauth-client" -version = "0.5.0" +version = "0.6.0" authors = ["Mason Gup "] edition = "2021" documentation = "https://docs.rs/mauth-client/" @@ -26,17 +26,18 @@ dirs = "5" chrono = "0.4" tokio = { version = "1", features = ["fs"] } tower = { version = "0.4", optional = true } -axum = { version = ">= 0.7.2", optional = true } +axum = { version = ">= 0.8", optional = true } futures-core = { version = "0.3", optional = true } http = "1" bytes = { version = "1", optional = true } thiserror = "1" -mauth-core = "0.5" +mauth-core = "0.6" +tracing = { version = "0.1", optional = true } [dev-dependencies] tokio = { version = "1", features = ["rt-multi-thread", "macros"] } [features] -axum-service = ["tower", "futures-core", "axum", "bytes"] +axum-service = ["tower", "futures-core", "axum", "bytes", "tracing"] tracing-otel-26 = ["reqwest-tracing/opentelemetry_0_26"] tracing-otel-27 = ["reqwest-tracing/opentelemetry_0_27"] diff --git a/README.md b/README.md index a800925..ba29e9f 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ the MAuth protocol, and verify the responses. Usage example: release any code to Production or deploy in a Client-accessible environment without getting approval for the full stack used through the Architecture and Security groups. +## Outgoing Requests + ```no_run use mauth_client::MAuthInfo; use reqwest::Client; @@ -49,9 +51,136 @@ match client.get("https://www.example.com/").send().await { # } ``` +## Incoming Requests + The optional `axum-service` feature provides for a Tower Layer and Service that will authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a -validated app_uuid from the request via the ValidatedRequestDetails struct. +validated app_uuid from the request via the `ValidatedRequestDetails` struct. Note that +this feature now includes a `RequiredMAuthValidationLayer`, which will reject any +requests without a valid signature before they reach lower layers, and also a +`OptionalMAuthValidationLayer`, which lets all requests through, but only attaches a +`ValidatedRequestDetails` extension struct if there is a valid signature. When using this +layer, it is the responsiblity of the request handler to check for the extension and +reject requests that are not properly authorized. + +Note that `ValidatedRequestDetails` implements Axum's `FromRequestParts`, so you can +specify it bare in a request handler. This implementation includes returning a 401 +Unauthorized status code if the extension is not present. If you would like to return +a different response, or respond to the lack of the extension in another way, you can +use a more manual mechanism to check for the extension and decide how to proceed if it +is not present. + +### Examples for `RequiredMAuthValidationLayer` + +```no_run +# async fn run_server() { +use mauth_client::{ + axum_service::RequiredMAuthValidationLayer, + validate_incoming::ValidatedRequestDetails, +}; +use axum::{http::StatusCode, Router, routing::get, serve}; +use tokio::net::TcpListener; + +// If there is not a valid mauth signature, this function will never run at all, and +// the request will return an empty 401 Unauthorized +async fn foo() -> StatusCode { + StatusCode::OK +} + +// In addition to returning a 401 Unauthorized without running if there is not a valid +// MAuth signature, this also makes the validated requesting app UUID available to +// the function +async fn bar(details: ValidatedRequestDetails) -> StatusCode { + println!("Got a request from app with UUID: {}", details.app_uuid); + StatusCode::OK +} + +// This function will run regardless of whether or not there is a mauth signature +async fn baz() -> StatusCode { + StatusCode::OK +} + +// Attaching the baz route handler after the layer means the layer is not run for +// requests to that path, so no mauth checking will be performed for that route and +// any other routes attached after the layer +let router = Router::new() + .route("/foo", get(foo)) + .route("/bar", get(bar)) + .layer(RequiredMAuthValidationLayer::from_default_file().unwrap()) + .route("/baz", get(baz)); +let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); +serve(listener, router).await.unwrap(); +# } +``` + +### Examples for `OptionalMAuthValidationLayer` + +```no_run +# async fn run_server() { +use mauth_client::{ + axum_service::OptionalMAuthValidationLayer, + validate_incoming::ValidatedRequestDetails, +}; +use axum::{http::StatusCode, Router, routing::get, serve}; +use tokio::net::TcpListener; + +// This request will run no matter what the authorization status is +async fn foo() -> StatusCode { + StatusCode::OK +} + +// If there is not a valid mauth signature, this function will never run at all, and +// the request will return an empty 401 Unauthorized +async fn bar(_: ValidatedRequestDetails) -> StatusCode { + StatusCode::OK +} + +// In addition to returning a 401 Unauthorized without running if there is not a valid +// MAuth signature, this also makes the validated requesting app UUID available to +// the function +async fn baz(details: ValidatedRequestDetails) -> StatusCode { + println!("Got a request from app with UUID: {}", details.app_uuid); + StatusCode::OK +} + +// This request will run whether or not there is a valid mauth signature, but the Option +// provided can be used to tell you whether there was a valid signature, so you can +// implement things like multiple possible types of authentication or behavior other than +// a 401 return if there is no authentication +async fn bam(optional_details: Option) -> StatusCode { + match optional_details { + Some(details) => println!("Got a request from app with UUID: {}", details.app_uuid), + None => println!("Got a request without a valid mauth signature"), + } + StatusCode::OK +} + +let router = Router::new() + .route("/foo", get(foo)) + .route("/bar", get(bar)) + .route("/baz", get(baz)) + .route("/bam", get(bam)) + .layer(OptionalMAuthValidationLayer::from_default_file().unwrap()); +let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); +serve(listener, router).await.unwrap(); +# } +``` + +### Error Handling + +Both the `RequiredMAuthValidationLayer` and the `OptionalMAuthValidationLayer` layers will +log errors encountered via `tracing` under the `mauth_client::validate_incoming` target. + +The Required layer returns the 401 response immediately, so there is no convenient way to +retrieve the error in order to do anything more sophisticated with it. + +The Optional layer, in addition to loging the error, will also add the `MAuthValidationError` +to the request extensions. If desired, any request handlers or middlewares can retrieve it +from there in order to take further actions based on the error type. This error type also +implements Axum's `OptionalFromRequestParts`, so you can more easily retrieve it using +`Option` anywhere that supports extractors. + +### OpenTelemetry Integration There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with the `axum-service` feature to ensure that any outgoing requests for credentials that take diff --git a/src/axum_service.rs b/src/axum_service.rs index 04cf6ae..4cdf1a6 100644 --- a/src/axum_service.rs +++ b/src/axum_service.rs @@ -1,11 +1,19 @@ //! Structs and impls related to providing a Tower Service and Layer to verify incoming requests -use axum::extract::Request; +use axum::{ + body::Body, + extract::{FromRequestParts, OptionalFromRequestParts, Request}, + response::IntoResponse, +}; use futures_core::future::BoxFuture; +use http::{request::Parts, Response, StatusCode}; +use std::convert::Infallible; use std::error::Error; use std::task::{Context, Poll}; use tower::{Layer, Service}; +use tracing::error; +use crate::validate_incoming::{MAuthValidationError, ValidatedRequestDetails}; use crate::{ config::{ConfigFileSection, ConfigReadError}, MAuthInfo, @@ -14,24 +22,25 @@ use crate::{ /// This is a Tower Service which validates that incoming requests have a valid /// MAuth signature. It only passes the request down to the next layer if the /// signature is valid, otherwise it returns an appropriate error to the caller. -pub struct MAuthValidationService { +pub struct RequiredMAuthValidationService { mauth_info: MAuthInfo, config_info: ConfigFileSection, service: S, } -impl Service for MAuthValidationService +impl Service for RequiredMAuthValidationService where S: Service + Send + Clone + 'static, S::Future: Send + 'static, S::Error: Into>, + S::Response: Into>, { - type Response = S::Response; - type Error = Box; + type Response = Response; + type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx).map_err(|e| e.into()) + self.service.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { @@ -39,18 +48,111 @@ where Box::pin(async move { match cloned.mauth_info.validate_request(request).await { Ok(valid_request) => match cloned.service.call(valid_request).await { - Ok(response) => Ok(response), - Err(err) => Err(err.into()), + Ok(response) => Ok(response.into()), + Err(err) => Err(err), }, - Err(err) => Err(Box::new(err) as Box), + Err(err) => { + error!( + error = ?err, + "Failed to validate MAuth signature, rejecting request" + ); + Ok(StatusCode::UNAUTHORIZED.into_response()) + } } }) } } -impl Clone for MAuthValidationService { +impl Clone for RequiredMAuthValidationService { + fn clone(&self) -> Self { + RequiredMAuthValidationService { + // unwrap is safe because we validated the config_info before constructing the layer + mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), + config_info: self.config_info.clone(), + service: self.service.clone(), + } + } +} + +/// This is a Tower Layer which applies the RequiredMAuthValidationService on top of the +/// service provided to it. +#[derive(Clone)] +pub struct RequiredMAuthValidationLayer { + config_info: ConfigFileSection, +} + +impl Layer for RequiredMAuthValidationLayer { + type Service = RequiredMAuthValidationService; + + fn layer(&self, service: S) -> Self::Service { + RequiredMAuthValidationService { + // unwrap is safe because we validated the config_info before constructing the layer + mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), + config_info: self.config_info.clone(), + service, + } + } +} + +impl RequiredMAuthValidationLayer { + /// Construct a RequiredMAuthValidationLayer based on the configuration options in the file + /// found in the default location. + pub fn from_default_file() -> Result { + let config_info = MAuthInfo::config_section_from_default_file()?; + // Generate a MAuthInfo and then drop it to validate that it works, + // making it safe to use `unwrap` in the service constructor. + MAuthInfo::from_config_section(&config_info)?; + Ok(RequiredMAuthValidationLayer { config_info }) + } + + /// Construct a RequiredMAuthValidationLayer based on the configuration options in a manually + /// created or parsed ConfigFileSection. + pub fn from_config_section(config_info: ConfigFileSection) -> Result { + MAuthInfo::from_config_section(&config_info)?; + Ok(RequiredMAuthValidationLayer { config_info }) + } +} + +/// This is a Tower Service which validates that incoming requests have a valid +/// MAuth signature. Unlike the Required service, if this service is not able to +/// find or validate a signature, it passes the request down to the lower layers +/// anyways. This means that it is the responsibility of the request handler to +/// check for the `ValidatedRequestDetails` extension to determine if the request +/// has a valid signature. It also means that this service is safe to attach to +/// the whole application, even if some requests are not validated at all or may +/// be validated in a different way. +pub struct OptionalMAuthValidationService { + mauth_info: MAuthInfo, + config_info: ConfigFileSection, + service: S, +} + +impl Service for OptionalMAuthValidationService +where + S: Service + Send + Clone + 'static, + S::Future: Send + 'static, + S::Error: Into>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let mut cloned = self.clone(); + Box::pin(async move { + let processed_request = cloned.mauth_info.validate_request_optionally(request).await; + cloned.service.call(processed_request).await + }) + } +} + +impl Clone for OptionalMAuthValidationService { fn clone(&self) -> Self { - MAuthValidationService { + OptionalMAuthValidationService { // unwrap is safe because we validated the config_info before constructing the layer mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), config_info: self.config_info.clone(), @@ -59,18 +161,18 @@ impl Clone for MAuthValidationService { } } -/// This is a Tower Layer which applies the MAuthValidationService on top of the +/// This is a Tower Layer which applies the OptionalMAuthValidationService on top of the /// service provided to it. #[derive(Clone)] -pub struct MAuthValidationLayer { +pub struct OptionalMAuthValidationLayer { config_info: ConfigFileSection, } -impl Layer for MAuthValidationLayer { - type Service = MAuthValidationService; +impl Layer for OptionalMAuthValidationLayer { + type Service = OptionalMAuthValidationService; fn layer(&self, service: S) -> Self::Service { - MAuthValidationService { + OptionalMAuthValidationService { // unwrap is safe because we validated the config_info before constructing the layer mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), config_info: self.config_info.clone(), @@ -79,21 +181,64 @@ impl Layer for MAuthValidationLayer { } } -impl MAuthValidationLayer { - /// Construct a MAuthValidationLayer based on the configuration options in the file +impl OptionalMAuthValidationLayer { + /// Construct an OptionalMAuthValidationLayer based on the configuration options in the file /// found in the default location. pub fn from_default_file() -> Result { let config_info = MAuthInfo::config_section_from_default_file()?; // Generate a MAuthInfo and then drop it to validate that it works, // making it safe to use `unwrap` in the service constructor. MAuthInfo::from_config_section(&config_info)?; - Ok(MAuthValidationLayer { config_info }) + Ok(OptionalMAuthValidationLayer { config_info }) } - /// Construct a MAuthValidationLayer based on the configuration options in a manually + /// Construct an OptionalMAuthValidationLayer based on the configuration options in a manually /// created or parsed ConfigFileSection. pub fn from_config_section(config_info: ConfigFileSection) -> Result { MAuthInfo::from_config_section(&config_info)?; - Ok(MAuthValidationLayer { config_info }) + Ok(OptionalMAuthValidationLayer { config_info }) + } +} + +impl FromRequestParts for ValidatedRequestDetails +where + S: Send + Sync, +{ + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(StatusCode::UNAUTHORIZED) + } +} + +impl OptionalFromRequestParts for ValidatedRequestDetails +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} + +impl OptionalFromRequestParts for MAuthValidationError +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) } } diff --git a/src/reqwest_middleware.rs b/src/reqwest_middleware.rs index 71f6c0d..80c6cb8 100644 --- a/src/reqwest_middleware.rs +++ b/src/reqwest_middleware.rs @@ -6,7 +6,6 @@ use crate::{sign_outgoing::SigningError, MAuthInfo}; #[async_trait::async_trait] impl Middleware for MAuthInfo { - #[must_use] async fn handle( &self, mut req: Request, diff --git a/src/validate_incoming.rs b/src/validate_incoming.rs index ac02f5d..dbfc870 100644 --- a/src/validate_incoming.rs +++ b/src/validate_incoming.rs @@ -1,7 +1,10 @@ use crate::{MAuthInfo, CLIENT, PUBKEY_CACHE}; +use axum::extract::Request; +use bytes::Bytes; use chrono::prelude::*; use mauth_core::verifier::Verifier; use thiserror::Error; +use tracing::error; use uuid::Uuid; /// This struct holds the app UUID for a validated request. It is meant to be used with the @@ -15,11 +18,16 @@ pub struct ValidatedRequestDetails { pub app_uuid: Uuid, } +const MAUTH_V1_SIGNATURE_HEADER: &str = "X-MWS-Authentication"; +const MAUTH_V2_SIGNATURE_HEADER: &str = "MCC-Authentication"; +const MAUTH_V1_TIMESTAMP_HEADER: &str = "X-MWS-Time"; +const MAUTH_V2_TIMESTAMP_HEADER: &str = "MCC-Time"; + impl MAuthInfo { pub(crate) async fn validate_request( &self, - req: axum::extract::Request, - ) -> Result { + req: Request, + ) -> Result { let (mut parts, body) = req.into_parts(); let body_bytes = axum::body::to_bytes(body, usize::MAX) .await @@ -30,7 +38,7 @@ impl MAuthInfo { app_uuid: host_app_uuid, }); let new_body = axum::body::Body::from(body_bytes); - let new_request = axum::extract::Request::from_parts(parts, new_body); + let new_request = Request::from_parts(parts, new_body); Ok(new_request) } Err(err) => { @@ -41,7 +49,7 @@ impl MAuthInfo { app_uuid: host_app_uuid, }); let new_body = axum::body::Body::from(body_bytes); - let new_request = axum::extract::Request::from_parts(parts, new_body); + let new_request = Request::from_parts(parts, new_body); Ok(new_request) } Err(err) => Err(err), @@ -53,6 +61,64 @@ impl MAuthInfo { } } + pub(crate) async fn validate_request_optionally(&self, req: Request) -> Request { + let (mut parts, body) = req.into_parts(); + if parts.headers.contains_key(MAUTH_V2_SIGNATURE_HEADER) + || parts.headers.contains_key(MAUTH_V1_SIGNATURE_HEADER) + { + // By my reading of the code for this it should never fail, since we are passing + // MAX for the limit. But just to be safe, we will log the error and proceed with + // an empty body just in case instead of unwrapping. This would cause the body to + // be unavailable to the lower layers, but they would probably also fail to get it + // anyways since we just did here. + let body_bytes = match axum::body::to_bytes(body, usize::MAX).await { + Ok(bytes) => bytes, + Err(error) => { + error!( + ?error, + "Failed to retrieve request body, continuing with empty body" + ); + Bytes::new() + } + }; + + match self.validate_request_v2(&parts, &body_bytes).await { + Ok(host_app_uuid) => { + parts.extensions.insert(ValidatedRequestDetails { + app_uuid: host_app_uuid, + }); + } + Err(error_v2) => { + if self.allow_v1_auth { + match self.validate_request_v1(&parts, &body_bytes).await { + Ok(host_app_uuid) => { + parts.extensions.insert(ValidatedRequestDetails { + app_uuid: host_app_uuid, + }); + } + Err(error_v1) => { + error!( + ?error_v2, + ?error_v1, + "Error attempting to validate MAuth signatures" + ); + parts.extensions.insert(error_v1); + } + } + } else { + error!(?error_v2, "Error attempting to validate MAuth V2 signature"); + parts.extensions.insert(error_v2); + } + } + } + + let new_body = axum::body::Body::from(body_bytes); + Request::from_parts(parts, new_body) + } else { + Request::from_parts(parts, body) + } + } + async fn validate_request_v2( &self, req: &http::request::Parts, @@ -61,7 +127,7 @@ impl MAuthInfo { //retrieve and parse auth string let sig_header = req .headers - .get("MCC-Authentication") + .get(MAUTH_V2_SIGNATURE_HEADER) .ok_or(MAuthValidationError::NoSig)? .to_str() .map_err(|_| MAuthValidationError::InvalidSignature)?; @@ -70,7 +136,7 @@ impl MAuthInfo { //retrieve and validate timestamp let ts_str = req .headers - .get("MCC-Time") + .get(MAUTH_V2_TIMESTAMP_HEADER) .ok_or(MAuthValidationError::NoTime)? .to_str() .map_err(|_| MAuthValidationError::InvalidTime)?; @@ -107,7 +173,7 @@ impl MAuthInfo { //retrieve and parse auth string let sig_header = req .headers - .get("X-MWS-Authentication") + .get(MAUTH_V1_SIGNATURE_HEADER) .ok_or(MAuthValidationError::NoSig)? .to_str() .map_err(|_| MAuthValidationError::InvalidSignature)?; @@ -116,7 +182,7 @@ impl MAuthInfo { //retrieve and validate timestamp let ts_str = req .headers - .get("X-MWS-Time") + .get(MAUTH_V1_TIMESTAMP_HEADER) .ok_or(MAuthValidationError::NoTime)? .to_str() .map_err(|_| MAuthValidationError::InvalidTime)?; @@ -218,7 +284,7 @@ impl MAuthInfo { } /// All of the possible errors that can take place when attempting to verify a response signature -#[derive(Debug, Error)] +#[derive(Debug, Error, Clone)] pub enum MAuthValidationError { /// The timestamp of the response was either invalid or outside of the permitted /// range