Skip to content

Commit

Permalink
feat! Update on interval
Browse files Browse the repository at this point in the history
  • Loading branch information
kjellkongsvik committed Sep 15, 2023
1 parent 4bf3110 commit 15c0371
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 110 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1", features = ["derive"] }
thiserror = { version = "1" }
tracing = { version = "0.1" }
tokio = { version = "1", features = ["macros"] }

[dev-dependencies]
tokio = { version = "1", features = ["macros"] }
serde_json = { version = "1" }
8 changes: 4 additions & 4 deletions src/claims.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use axum::{
};
use serde::de::DeserializeOwned;

use crate::{Jwks, Token, TokenError};
use crate::{KeyManager, Token, TokenError};

pub struct Claims<C: DeserializeOwned + ParseTokenClaims>(pub C);

Expand Down Expand Up @@ -82,16 +82,16 @@ pub trait ParseTokenClaims {
impl<S, C> FromRequestParts<S> for Claims<C>
where
C: DeserializeOwned + ParseTokenClaims,
Jwks: FromRef<S>,
KeyManager: FromRef<S>,
S: Send + Sync,
{
type Rejection = C::Rejection;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let jwks = Jwks::from_ref(state);
let key_manager = KeyManager::from_ref(state);
let token = Token::from_request_parts(parts, state).await?;

let token_data = jwks.validate_claims(token.value())?;
let token_data = key_manager.validate_claims(token.value()).await?;

Ok(Claims(token_data.claims))
}
Expand Down
106 changes: 106 additions & 0 deletions src/key_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::sync::Arc;

use crate::{KeyStore, TokenError};
use jsonwebtoken::{decode, decode_header, TokenData};
use serde::de::DeserializeOwned;

use tokio::time::{sleep, Duration};
use tokio::{sync::Mutex, time::Instant};
use tracing::{debug, info};

#[derive(Clone)]
pub struct KeyManager {
pub oidc_url: String,
pub audience: String,
pub minimal_interval: Duration,
pub key_store: Arc<Mutex<KeyStore>>,
}

impl KeyManager {
/// Create a new KeyManager that will fetch jwks from an oidc
///
/// The first fetching will be done either:
/// - on first call to `validate_claims`
/// - immediately if `with_update_interval` is used
///
/// If a kid from a token is not in the KeyStore and
/// the time since last update is > `minimal_interval`
/// a new fetch of jwks is executed in `validate_claims`
///
pub fn new(oidc_url: String, audience: String, minimal_interval: Duration) -> Self {
let key_store = Arc::new(Mutex::new(KeyStore::new()));
Self {
oidc_url,
audience,
minimal_interval,
key_store,
}
}

pub fn with_update_jwks_interval(self, interval: Duration) -> Self {
if interval.is_zero() {
panic!("interval must be > 0");
}
let key_store = self.key_store.clone();
let oidc_url = self.oidc_url.clone();
let audience = self.audience.clone();
tokio::spawn(async move {
let client = reqwest::Client::default();
loop {
{
let mut ks = key_store.lock().await;
match KeyStore::from_oidc_url(&client, &oidc_url, &audience).await {
Ok(new_ks) => *ks = new_ks,
Err(e) => info!(?e, "Could not update jwks from: {}", oidc_url),
}
}
sleep(interval).await;
}
});
self
}

pub async fn validate_claims<T>(&self, token: &str) -> Result<TokenData<T>, TokenError>
where
T: DeserializeOwned,
{
let header = decode_header(token).map_err(|error| {
debug!(?error, "Received token with invalid header.");

TokenError::InvalidHeader(error)
})?;
let kid = header.kid.as_ref().ok_or_else(|| {
debug!(?header, "Header is missing the `kid` attribute.");

TokenError::MissingKeyId
})?;

let mut ks = self.key_store.lock().await;
if ks.keys.get(kid).is_none() && ks.when + self.minimal_interval < Instant::now() {
let client = reqwest::Client::default();
match KeyStore::from_oidc_url(&client, &self.oidc_url, &self.audience).await {
Ok(new_ks) => *ks = new_ks,
Err(e) => info!(?e, "Could not update jwks from: {}", self.oidc_url),
}
if let Ok(new_ks) =
KeyStore::from_oidc_url(&client, &self.oidc_url, &self.audience).await
{
*ks = new_ks;
}
}
let key = ks.keys.get(kid).ok_or_else(|| {
debug!(%kid, "Token refers to an unknown key.");

TokenError::UnknownKeyId(kid.to_owned())
})?;

let decoded_token: TokenData<T> =
decode(token, &key.decoding, &key.validation).map_err(|error| {
debug!(?error, "Token is malformed or does not pass validation.");

TokenError::Invalid(error)
})?;

Ok(decoded_token)
}
}
137 changes: 44 additions & 93 deletions src/jwks.rs → src/key_store.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,46 @@
use std::{collections::HashMap, str::FromStr};

use jsonwebtoken::{
decode, decode_header,
jwk::{self, AlgorithmParameters},
DecodingKey, TokenData, Validation,
DecodingKey, Validation,
};
use serde::{de::DeserializeOwned, Deserialize};
use serde::Deserialize;
use thiserror::Error;
use tracing::{debug, info};

use crate::TokenError;
use tokio::time::{Duration, Instant};
use tracing::{debug, info};

/// A container for a set of JWT decoding keys.
///
/// The container can be used to validate any JWT that identifies a known key
/// through the `kid` attribute in the token's header.
#[derive(Clone)]
pub struct Jwks {
keys: HashMap<String, Jwk>,
pub struct KeyStore {
pub when: Instant,
pub keys: Keys,
}

#[derive(Deserialize)]
struct Oid {
jwks_uri: String,
id_token_signing_alg_values_supported: Option<Vec<String>>,
impl KeyStore {
/// Create an empty KeyStore with update time far in the past
/// This will trigger a refresh on first use
pub fn new() -> Self {
Self {
when: Instant::now() - Duration::from_secs(100000),
keys: Keys::new(),
}
}
}

impl Jwks {
/// Pull a JSON Web Key Set from a specific authority.

///
/// # Arguments
/// * `oidc_url` - The url with Openid-configuration.
/// * `audience` - The identifier of the consumer of the JWT. This will be
/// matched against the `aud` claim from the token.
///
/// # Return Value
/// The information needed to decode JWTs using any of the keys specified in
/// the authority's JWKS.
pub async fn from_oidc_url(oidc_url: &str, audience: String) -> Result<Self, JwksError> {
Self::from_oidc_url_with_client(&reqwest::Client::default(), oidc_url, audience).await
impl Default for KeyStore {
fn default() -> Self {
Self::new()
}
}

/// A version of [`from_oidc`][Self::from_oidc] that allows for
/// passing in a custom [`Client`][reqwest::Client].
pub async fn from_oidc_url_with_client(
impl KeyStore {
pub async fn from_oidc_url(
client: &reqwest::Client,
oidc_url: &str,
audience: String,
audience: &str,
) -> Result<Self, JwksError> {
debug!(%oidc_url, "Fetching openid-configuration.");
let oidc = client.get(oidc_url).send().await?.json::<Oid>().await?;
let oidc = client.get(oidc_url).send().await?.json::<Oidc>().await?;
let jwks_uri = oidc.jwks_uri;
let alg = match &oidc.id_token_signing_alg_values_supported {
Some(algs) => match algs.first() {
Expand All @@ -60,35 +50,19 @@ impl Jwks {
_ => None,
};

Self::from_jwks_url_with_client(&reqwest::Client::default(), &jwks_uri, audience, alg).await
let keys = Self::from_jwks_url(client, &jwks_uri, audience, alg).await?;
Ok(Self {
keys,
when: Instant::now(),
})
}

///
/// # Arguments
/// * `jwks_url` - The url which JWKS info is pulled from.
/// * `audience` - The identifier of the consumer of the JWT. This will be
/// matched against the `aud` claim from the token.
/// * `alg` - The alg to use if not specified in JWK
///
/// # Return Value
/// The information needed to decode JWTs using any of the keys specified in
/// the authority's JWKS.
pub async fn from_jwks_url(
jwks_url: &str,
audience: String,
alg: Option<jsonwebtoken::Algorithm>,
) -> Result<Self, JwksError> {
Self::from_jwks_url_with_client(&reqwest::Client::default(), jwks_url, audience, alg).await
}

/// A version of [`from_jwks`][Self::from_jwks] that allows for
/// passing in a custom [`Client`][reqwest::Client].
pub async fn from_jwks_url_with_client(
async fn from_jwks_url(
client: &reqwest::Client,
jwks_url: &str,
audience: String,
audience: &str,
alg: Option<jsonwebtoken::Algorithm>,
) -> Result<Self, JwksError> {
) -> Result<Keys, JwksError> {
debug!(%jwks_url, "Fetching JSON Web Key Set.");
let jwks: jwk::JwkSet = client.get(jwks_url).send().await?.json().await?;
info!(
Expand Down Expand Up @@ -135,45 +109,22 @@ impl Jwks {
}
}

Ok(Self { keys })
Ok(keys)
}
}

pub fn validate_claims<T>(&self, token: &str) -> Result<TokenData<T>, TokenError>
where
T: DeserializeOwned,
{
let header = decode_header(token).map_err(|error| {
debug!(?error, "Received token with invalid header.");

TokenError::InvalidHeader(error)
})?;
let kid = header.kid.as_ref().ok_or_else(|| {
debug!(?header, "Header is missing the `kid` attribute.");

TokenError::MissingKeyId
})?;

let key = self.keys.get(kid).ok_or_else(|| {
debug!(%kid, "Token refers to an unknown key.");

TokenError::UnknownKeyId(kid.to_owned())
})?;

let decoded_token: TokenData<T> =
decode(token, &key.decoding, &key.validation).map_err(|error| {
debug!(?error, "Token is malformed or does not pass validation.");

TokenError::Invalid(error)
})?;

Ok(decoded_token)
}
#[derive(Deserialize)]
struct Oidc {
jwks_uri: String,
id_token_signing_alg_values_supported: Option<Vec<String>>,
}

type Keys = HashMap<String, Jwk>;

#[derive(Clone)]
struct Jwk {
decoding: DecodingKey,
validation: Validation,
pub struct Jwk {
pub decoding: DecodingKey,
pub validation: Validation,
}

/// An error with the overall set of JSON Web Keys.
Expand All @@ -193,7 +144,7 @@ pub enum JwksError {
InvalidAlgorithm(#[from] jsonwebtoken::errors::Error),
}

/// An error with a specific key from a JWKS.
// An error with a specific key from a JWKS.
#[derive(Debug, Error)]
pub enum JwkError {
/// There was an error constructing the decoding key from the RSA components
Expand Down
27 changes: 15 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
//! Json,
//! Router,
//! };
//! use axum_jwks::{Claims, Jwks, ParseTokenClaims, TokenError};
//! use axum_jwks::{Claims, KeyManager, ParseTokenClaims, TokenError};
//! use serde::{Deserialize, Serialize};
//! use tokio::time::Duration;
//!
//! // The state available to all your route handlers.
//! #[derive(Clone)]
//! struct AppState {
//! jwks: Jwks,
//! key_manager: KeyManager,
//! }
//!
//! impl FromRef<AppState> for Jwks {
//! impl FromRef<AppState> for KeyManager {
//! fn from_ref(state: &AppState) -> Self {
//! state.jwks.clone()
//! state.key_manager.clone()
//! }
//! }
//!
Expand Down Expand Up @@ -71,26 +72,28 @@
//! }
//!
//! async fn create_router() -> Router<AppState> {
//! let jwks = Jwks::from_oidc_url(
//! let key_manager = KeyManager::new(
//! // The Authorization Server that signs the JWTs you want to consume.
//! "https://my-auth-server.example.com/.well-known/openid-configuration",
//! "https://my-auth-server.example.com/.well-known/openid-configuration".to_owned(),
//! // The audience identifier for the application. This ensures that
//! // JWTs are intended for this application.
//! "https://my-api-identifier.example.com/".to_owned(),
//! )
//! .await
//! .unwrap();
//! //
//! Duration::from_secs(600),
//! );
//!
//! Router::new()
//! .route("/echo-claims", get(echo_claims))
//! .with_state(AppState { jwks })
//! .with_state(AppState { key_manager })
//! }
//! ```

mod claims;
mod jwks;
mod key_manager;
mod key_store;
mod token;

pub use claims::{Claims, ParseTokenClaims};
pub use jwks::{JwkError, Jwks, JwksError};
pub use key_manager::KeyManager;
pub use key_store::KeyStore;
pub use token::{Token, TokenError};

0 comments on commit 15c0371

Please sign in to comment.