-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4bf3110
commit 15c0371
Showing
5 changed files
with
170 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
use std::sync::Arc; | ||
|
||
use crate::{KeyStore, TokenError}; | ||
use jsonwebtoken::{decode, decode_header, TokenData}; | ||
use serde::de::DeserializeOwned; | ||
|
||
use tokio::time::{sleep, Duration}; | ||
use tokio::{sync::Mutex, time::Instant}; | ||
use tracing::{debug, info}; | ||
|
||
#[derive(Clone)] | ||
pub struct KeyManager { | ||
pub oidc_url: String, | ||
pub audience: String, | ||
pub minimal_interval: Duration, | ||
pub key_store: Arc<Mutex<KeyStore>>, | ||
} | ||
|
||
impl KeyManager { | ||
/// Create a new KeyManager that will fetch jwks from an oidc | ||
/// | ||
/// The first fetching will be done either: | ||
/// - on first call to `validate_claims` | ||
/// - immediately if `with_update_interval` is used | ||
/// | ||
/// If a kid from a token is not in the KeyStore and | ||
/// the time since last update is > `minimal_interval` | ||
/// a new fetch of jwks is executed in `validate_claims` | ||
/// | ||
pub fn new(oidc_url: String, audience: String, minimal_interval: Duration) -> Self { | ||
let key_store = Arc::new(Mutex::new(KeyStore::new())); | ||
Self { | ||
oidc_url, | ||
audience, | ||
minimal_interval, | ||
key_store, | ||
} | ||
} | ||
|
||
pub fn with_update_jwks_interval(self, interval: Duration) -> Self { | ||
if interval.is_zero() { | ||
panic!("interval must be > 0"); | ||
} | ||
let key_store = self.key_store.clone(); | ||
let oidc_url = self.oidc_url.clone(); | ||
let audience = self.audience.clone(); | ||
tokio::spawn(async move { | ||
let client = reqwest::Client::default(); | ||
loop { | ||
{ | ||
let mut ks = key_store.lock().await; | ||
match KeyStore::from_oidc_url(&client, &oidc_url, &audience).await { | ||
Ok(new_ks) => *ks = new_ks, | ||
Err(e) => info!(?e, "Could not update jwks from: {}", oidc_url), | ||
} | ||
} | ||
sleep(interval).await; | ||
} | ||
}); | ||
self | ||
} | ||
|
||
pub async fn validate_claims<T>(&self, token: &str) -> Result<TokenData<T>, TokenError> | ||
where | ||
T: DeserializeOwned, | ||
{ | ||
let header = decode_header(token).map_err(|error| { | ||
debug!(?error, "Received token with invalid header."); | ||
|
||
TokenError::InvalidHeader(error) | ||
})?; | ||
let kid = header.kid.as_ref().ok_or_else(|| { | ||
debug!(?header, "Header is missing the `kid` attribute."); | ||
|
||
TokenError::MissingKeyId | ||
})?; | ||
|
||
let mut ks = self.key_store.lock().await; | ||
if ks.keys.get(kid).is_none() && ks.when + self.minimal_interval < Instant::now() { | ||
let client = reqwest::Client::default(); | ||
match KeyStore::from_oidc_url(&client, &self.oidc_url, &self.audience).await { | ||
Ok(new_ks) => *ks = new_ks, | ||
Err(e) => info!(?e, "Could not update jwks from: {}", self.oidc_url), | ||
} | ||
if let Ok(new_ks) = | ||
KeyStore::from_oidc_url(&client, &self.oidc_url, &self.audience).await | ||
{ | ||
*ks = new_ks; | ||
} | ||
} | ||
let key = ks.keys.get(kid).ok_or_else(|| { | ||
debug!(%kid, "Token refers to an unknown key."); | ||
|
||
TokenError::UnknownKeyId(kid.to_owned()) | ||
})?; | ||
|
||
let decoded_token: TokenData<T> = | ||
decode(token, &key.decoding, &key.validation).map_err(|error| { | ||
debug!(?error, "Token is malformed or does not pass validation."); | ||
|
||
TokenError::Invalid(error) | ||
})?; | ||
|
||
Ok(decoded_token) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters