Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor auth middleware for supporting bearer, cookie and query #560

Merged
merged 3 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions docs-site/content/docs/the-app/controller.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,72 @@ impl Hooks for App {

# Middleware

### Authentication
In the `Loco` framework, middleware plays a crucial role in authentication. `Loco` supports various authentication methods, including JSON Web Token (JWT) and API Key authentication. This section outlines how to configure and use authentication middleware in your application.

#### JSON Web Token (JWT)

##### Configuration
By default, Loco uses Bearer authentication for JWT. However, you can customize this behavior in the configuration file under the auth.jwt section.
* *Bearer Authentication:* Keep the configuration blank or explicitly set it as follows:
```yaml
# Authentication Configuration
auth:
# JWT authentication
jwt:
location: Bearer
...
```
* *Cookie Authentication:* Configure the location from which to extract the token and specify the cookie name:
```yaml
# Authentication Configuration
auth:
# JWT authentication
jwt:
location:
from: Cookie
name: token
...
```
* *Query Parameter Authentication:* Specify the location and name of the query parameter:
```yaml
# Authentication Configuration
auth:
# JWT authentication
jwt:
location:
from: Query
name: token
...
```

##### Usage
In your controller parameters, use `auth::JWT` for authentication. This triggers authentication validation based on the configured settings.
```rust
use loco_rs::prelude::*;

async fn current(
auth: auth::JWT,
State(_ctx): State<AppContext>,
) -> Result<Response> {
// Your implementation here
}
```
Additionally, you can fetch the current user by replacing auth::JWT with `auth::ApiToken<users::Model>`.

#### API Key
For API Key authentication, use auth::ApiToken. This middleware validates the API key against the user database record and loads the corresponding user into the authentication parameter.
```rust
use loco_rs::prelude::*;

async fn current(
auth: auth::ApiToken<users::Model>,
State(_ctx): State<AppContext>,
) -> Result<Response> {
// Your implementation here
}
```

## Compression

`Loco` leverages [CompressionLayer](https://docs.rs/tower-http/0.5.0/tower_http/compression/index.html) to enable a `one click` solution.
Expand Down
1 change: 1 addition & 0 deletions examples/demo/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions examples/demo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ sea-orm = { version = "1.0.0-rc.1", features = [

axum = { version = "0.7.1", features = ["multipart"] }
axum_session = { version = "0.10.1", default-features = false }
axum-extra = { version = "0.9", features = ["cookie"] }

include_dir = "0.7"
uuid = { version = "1.6.0", features = ["v4"] }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
source: tests/requests/user.rs
expression: "(response.status_code(), response.text())"
---
(
200,
"{\"pid\":\"PID\",\"name\":\"loco\",\"email\":\"test@loco.com\"}",
)
18 changes: 18 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,30 @@ pub struct Auth {
/// JWT configuration structure.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JWT {
/// The location where JWT tokens are expected to be found during
/// authentication.
pub location: Option<JWTLocation>,
/// The secret key For JWT token
pub secret: String,
/// The expiration time for authentication tokens
pub expiration: u64,
}

/// Defines the authentication mechanism for middleware.
///
/// This enum represents various ways to authenticate using JSON Web Tokens
/// (JWT) within middleware.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "from")]
pub enum JWTLocation {
/// Authenticate using a Bearer token.
Bearer,
/// Authenticate using a token passed as a query parameter.
Query { name: String },
/// Authenticate using a token stored in a cookie.
Cookie { name: String },
}

/// Server configuration structure.
///
/// Example (development):
Expand Down
146 changes: 129 additions & 17 deletions src/controller/middleware/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,26 @@
//! format::json(TestResponse{ pid: auth.claims.pid})
//! }
//! ```
use std::collections::HashMap;

use async_trait::async_trait;
use axum::{
extract::{FromRef, FromRequestParts},
extract::{FromRef, FromRequestParts, Query},
http::{request::Parts, HeaderMap},
};
use axum_extra::extract::cookie;
use serde::{Deserialize, Serialize};

use crate::{app::AppContext, auth, errors::Error, model::Authenticable};
use crate::{
app::AppContext, auth, config::JWT as JWTConfig, errors::Error, model::Authenticable,
Result as LocoResult,
};

// ---------------------------------------
//
// JWT Auth extractor
//
// ---------------------------------------

// Define constants for token prefix and authorization header
const TOKEN_PREFIX: &str = "Bearer ";
Expand All @@ -52,16 +63,15 @@ where
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
let token = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let ctx: AppContext = AppContext::from_ref(state);

let state: AppContext = AppContext::from_ref(state);
let token = extract_token(get_jwt_from_config(&ctx)?, parts)?;

let jwt_secret = state.config.get_jwt_config()?;
let jwt_secret = ctx.config.get_jwt_config()?;

match auth::jwt::JWT::new(&jwt_secret.secret).validate(&token) {
Ok(claims) => {
let user = T::find_by_claims_key(&state.db, &claims.claims.pid)
let user = T::find_by_claims_key(&ctx.db, &claims.claims.pid)
.await
.map_err(|_| Error::Unauthorized("token is not valid".to_string()))?;
Ok(Self {
Expand Down Expand Up @@ -93,12 +103,11 @@ where
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
let token = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let ctx: AppContext = AppContext::from_ref(state); // change to ctx

let state: AppContext = AppContext::from_ref(state);
let token = extract_token(get_jwt_from_config(&ctx)?, parts)?;

let jwt_secret = state.config.get_jwt_config()?;
let jwt_secret = ctx.config.get_jwt_config()?;

match auth::jwt::JWT::new(&jwt_secret.secret).validate(&token) {
Ok(claims) => Ok(Self {
Expand All @@ -111,20 +120,77 @@ where
}
}

/// extract JWT token from context configuration
///
/// # Errors
/// Return an error when JWT token not configured
fn get_jwt_from_config(ctx: &AppContext) -> LocoResult<&JWTConfig> {
ctx.config
.auth
.as_ref()
.ok_or_else(|| Error::string("auth not configured"))?
.jwt
.as_ref()
.ok_or_else(|| Error::string("JWT token not configured"))
}
/// extract token from the configured jwt location settings
fn extract_token(jwt_config: &JWTConfig, parts: &Parts) -> LocoResult<String> {
#[allow(clippy::match_wildcard_for_single_variants)]
match jwt_config
.location
.as_ref()
.unwrap_or(&crate::config::JWTLocation::Bearer)
{
crate::config::JWTLocation::Query { name } => extract_token_from_query(name, parts),
crate::config::JWTLocation::Cookie { name } => extract_token_from_cookie(name, parts),
crate::config::JWTLocation::Bearer => extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string())),
}
}
/// Function to extract a token from the authorization header
///
/// # Errors
///
/// When token is not valid or out found
pub fn extract_token_from_header(headers: &HeaderMap) -> eyre::Result<String> {
pub fn extract_token_from_header(headers: &HeaderMap) -> LocoResult<String> {
Ok(headers
.get(AUTH_HEADER)
.ok_or_else(|| eyre::eyre!("header {} token not found", AUTH_HEADER))?
.to_str()?
.ok_or_else(|| Error::Unauthorized(format!("header {AUTH_HEADER} token not found")))?
.to_str()
.map_err(|err| Error::Unauthorized(err.to_string()))?
.strip_prefix(TOKEN_PREFIX)
.ok_or_else(|| eyre::eyre!("error strip {} value", AUTH_HEADER))?
.ok_or_else(|| Error::Unauthorized(format!("error strip {AUTH_HEADER} value")))?
.to_string())
}

/// Extract a token value from cookie
///
/// # Errors
/// when token value from cookie is not found
pub fn extract_token_from_cookie(name: &str, parts: &Parts) -> LocoResult<String> {
// LogoResult
let jar: cookie::CookieJar = cookie::CookieJar::from_headers(&parts.headers);
Ok(jar
.get(name)
.ok_or(Error::Unauthorized("token is not found".to_string()))?
.to_string()
.strip_prefix(&format!("{name}="))
.ok_or_else(|| Error::Unauthorized("error strip value".to_string()))?
.to_string())
}
/// Extract a token value from query
///
/// # Errors
/// when token value from cookie is not found
pub fn extract_token_from_query(name: &str, parts: &Parts) -> LocoResult<String> {
// LogoResult
let parameters: Query<HashMap<String, String>> =
Query::try_from_uri(&parts.uri).map_err(|err| Error::Unauthorized(err.to_string()))?;
parameters
.get(name)
.cloned()
.ok_or_else(|| Error::Unauthorized(format!("`{name}` query parameter not found")))
}

// ---------------------------------------
//
Expand All @@ -151,8 +217,7 @@ where
// Extracts `ApiToken` from the request parts.
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Error> {
// Extract API key from the request header.
let api_key = extract_token_from_header(&parts.headers)
.map_err(|e| Error::Unauthorized(e.to_string()))?;
let api_key = extract_token_from_header(&parts.headers)?;

// Convert the state reference to the application context.
let state: AppContext = AppContext::from_ref(state);
Expand All @@ -165,3 +230,50 @@ where
Ok(Self { user })
}
}

#[cfg(test)]
mod tests {

use insta::assert_debug_snapshot;
use rstest::rstest;

use super::*;
use crate::config;

#[rstest]
#[case("extract_from_default", "https://loco.rs", None)]
#[case("extract_from_bearer", "loco.rs", Some(config::JWTLocation::Bearer))]
#[case("extract_from_cookie", "https://loco.rs", Some(config::JWTLocation::Cookie{name: "loco_cookie_key".to_string()}))]
#[case("extract_from_query", "https://loco.rs?query_token=query_token_value&test=loco", Some(config::JWTLocation::Query{name: "query_token".to_string()}))]
fn can_extract_token(
#[case] test_name: &str,
#[case] url: &str,
#[case] location: Option<config::JWTLocation>,
) {
let jwt_config = JWTConfig {
location,
secret: String::new(),
expiration: 1,
};

let request = axum::http::Request::builder()
.uri(url)
.header(AUTH_HEADER, format!("{TOKEN_PREFIX} bearer_token_value"))
.header(
"Cookie",
format!("{}={}", "loco_cookie_key", "cookie_token_value"),
)
.body(())
.unwrap();
let (parts, ()) = request.into_parts();
assert_debug_snapshot!(test_name, extract_token(&jwt_config, &parts));

// expected error
let request = axum::http::Request::builder()
.uri("https://loco.rs")
.body(())
.unwrap();
let (parts, ()) = request.into_parts();
assert!(extract_token(&jwt_config, &parts).is_err());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: src/controller/middleware/auth.rs
expression: "extract_token(&jwt_config, &parts)"
---
Ok(
" bearer_token_value",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: src/controller/middleware/auth.rs
expression: "extract_token(&jwt_config, &parts)"
---
Ok(
"cookie_token_value",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: src/controller/middleware/auth.rs
expression: "extract_token(&jwt_config, &parts)"
---
Ok(
" bearer_token_value",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: src/controller/middleware/auth.rs
expression: "extract_token(&jwt_config, &parts)"
---
Ok(
"query_token_value",
)
Loading