From d3ed7c956ae49bd53c1b347f4199c9fca219d9e8 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Thu, 18 Apr 2024 14:38:04 -0400 Subject: [PATCH] feat: key not from file --- Cargo.toml | 2 +- src/app/credentials.rs | 2 +- src/serv_account/errors.rs | 23 +++- src/serv_account/jwt.rs | 87 ++++++++------- src/serv_account/mod.rs | 220 ++++++++++++++++++++++++------------- 5 files changed, 203 insertions(+), 131 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2cc9a9e..9f3d53f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,10 +23,10 @@ ring = "0.16.20" thiserror = "1.0.48" anyhow = "1.0.40" futures = { version = "0.3", features = ["executor"], optional = true } +tokio = { version = "1.33.0", features = ["test-util", "sync"] } [dev-dependencies] mockito = "1.2.0" -tokio = { version = "1.33.0", features = ["test-util"] } [features] app-blocking = ["dep:futures"] diff --git a/src/app/credentials.rs b/src/app/credentials.rs index 3f22007..a1a4104 100644 --- a/src/app/credentials.rs +++ b/src/app/credentials.rs @@ -30,7 +30,7 @@ pub struct OauthCredentials { impl OauthCredentials { pub fn redirect_uri(&self) -> Result<&String> { self.redirect_uris - .get(0) + .first() .ok_or(AuthError::RedirectUriCfgError) } } diff --git a/src/serv_account/errors.rs b/src/serv_account/errors.rs index cd0ea9a..b6ee7b8 100644 --- a/src/serv_account/errors.rs +++ b/src/serv_account/errors.rs @@ -1,10 +1,12 @@ -use std::result::Result as StdResult; +use reqwest::StatusCode; +use ring::error::{KeyRejected, Unspecified}; +use std::{io, path::PathBuf, result::Result as StdResult}; use thiserror::Error; #[derive(Debug, Error)] pub enum ServiceAccountError { - #[error("failed to read key file: {0}")] - ReadKey(String), + #[error("failed to read key file: {0}: {1}")] + ReadKey(PathBuf, io::Error), #[error("failed to de/serialize to json")] SerdeJson(#[from] serde_json::Error), @@ -13,13 +15,22 @@ pub enum ServiceAccountError { Base64Decode(#[from] base64::DecodeError), #[error("failed to create rsa key pair: {0}")] - RsaKeyPair(String), + RsaKeyPair(KeyRejected), #[error("failed to rsa sign: {0}")] - RsaSign(String), + RsaSign(Unspecified), #[error("failed to send request")] - HttpReqwest(#[from] reqwest::Error), + HttpRequest(reqwest::Error), + + #[error("failed to send request")] + HttpRequestUnsuccessful(StatusCode, std::result::Result), + + #[error("failed to get response JSON")] + HttpJson(reqwest::Error), + + #[error("response returned non-Bearer auth access token: {0}")] + AccessTokenNotBeaarer(String), } pub type Result = StdResult; diff --git a/src/serv_account/jwt.rs b/src/serv_account/jwt.rs index d61afb6..d60244d 100644 --- a/src/serv_account/jwt.rs +++ b/src/serv_account/jwt.rs @@ -1,9 +1,15 @@ use super::errors::{Result, ServiceAccountError}; +use base64::{engine::general_purpose, Engine as _}; +use ring::{ + rand, + signature::{self, RsaKeyPair}, +}; +use serde_derive::Deserialize; use serde_derive::Serialize; -#[derive(Clone, Debug, Default, Serialize)] +#[derive(Debug)] pub struct JwtToken { - private_key: String, + key_pair: RsaKeyPair, header: JwtHeader, payload: JwtPayload, } @@ -24,41 +30,36 @@ struct JwtPayload { iat: u64, } -use base64::{engine::general_purpose, Engine as _}; -use ring::{rand, signature}; -use serde_derive::Deserialize; - impl JwtToken { - /// Creates a new JWT token from a service account key file - pub fn from_file(key_path: &str) -> Result { - let private_key_content = std::fs::read(key_path) - .map_err(|err| ServiceAccountError::ReadKey(format!("{}: {}", err, key_path)))?; - - let key_data = serde_json::from_slice::(&private_key_content)?; - + /// Creates a new JWT token from a service account key + pub fn from_key(key: &ServiceAccountKey) -> Result { let iat = chrono::Utc::now().timestamp() as u64; let exp = iat + 3600; - let private_key = key_data + let private_key = key .private_key .replace('\n', "") .replace("-----BEGIN PRIVATE KEY-----", "") .replace("-----END PRIVATE KEY-----", ""); + let private_key = private_key.as_bytes(); + let decoded = general_purpose::STANDARD.decode(private_key)?; + let key_pair = RsaKeyPair::from_pkcs8(&decoded).map_err(ServiceAccountError::RsaKeyPair)?; + Ok(Self { header: JwtHeader { alg: String::from("RS256"), typ: String::from("JWT"), }, payload: JwtPayload { - iss: key_data.client_email, + iss: key.client_email.clone(), sub: None, scope: String::new(), - aud: key_data.token_uri, + aud: key.token_uri.clone(), exp, iat, }, - private_key, + key_pair, }) } @@ -100,54 +101,52 @@ impl JwtToken { /// Signs a message with the private key fn sign_rsa(&self, message: String) -> Result> { - let private_key = self.private_key.as_bytes(); - let decoded = general_purpose::STANDARD.decode(private_key)?; - - let key_pair = signature::RsaKeyPair::from_pkcs8(&decoded).map_err(|err| { - ServiceAccountError::RsaKeyPair(format!("failed tp create key pair: {}", err)) - })?; - // Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm. let rng = rand::SystemRandom::new(); - let mut signature = vec![0; key_pair.public_modulus_len()]; - key_pair + let mut signature = vec![0; self.key_pair.public_modulus_len()]; + self.key_pair .sign( &signature::RSA_PKCS1_SHA256, &rng, message.as_bytes(), &mut signature, ) - .map_err(|err| ServiceAccountError::RsaSign(format!("{}", err)))?; + .map_err(ServiceAccountError::RsaSign)?; Ok(signature) } } #[allow(dead_code)] -#[derive(Debug, Deserialize)] -struct ServiceAccountKey { - r#type: String, - project_id: String, - private_key_id: String, - private_key: String, - client_email: String, - client_id: String, - auth_uri: String, - token_uri: String, - auth_provider_x509_cert_url: String, - client_x509_cert_url: String, - universe_domain: String, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ServiceAccountKey { + pub r#type: String, + pub project_id: String, + pub private_key_id: String, + pub private_key: String, + pub client_email: String, + pub client_id: String, + pub auth_uri: String, + pub token_uri: String, + pub auth_provider_x509_cert_url: String, + pub client_x509_cert_url: String, + pub universe_domain: String, } #[cfg(test)] mod tests { use super::*; - const SERVICE_ACCOUNT_KEY_PATH: &str = "test_fixtures/service-account-key.json"; + fn read_key() -> ServiceAccountKey { + serde_json::from_slice(include_bytes!( + "../../test_fixtures/service-account-key.json" + )) + .unwrap() + } #[test] fn test_jwt_token() { - let mut token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap(); + let mut token = JwtToken::from_key(&read_key()).unwrap(); assert_eq!(token.header.alg, "RS256"); assert_eq!(token.header.typ, "JWT"); @@ -170,7 +169,7 @@ mod tests { fn test_sign_rsa() { let message = String::from("hello, world"); - let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap(); + let token = JwtToken::from_key(&read_key()).unwrap(); let signature = token.sign_rsa(message).unwrap(); assert_eq!(signature.len(), 256); @@ -178,7 +177,7 @@ mod tests { #[test] fn test_token_to_string() { - let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH) + let token = JwtToken::from_key(&read_key()) .unwrap() .sub(String::from("some@email.com")) .scope(String::from("https://www.googleapis.com/auth/pubsub")); diff --git a/src/serv_account/mod.rs b/src/serv_account/mod.rs index 0e10d53..39e5cc4 100644 --- a/src/serv_account/mod.rs +++ b/src/serv_account/mod.rs @@ -1,50 +1,43 @@ -use chrono::Utc; +use self::errors::ServiceAccountError; +use chrono::{DateTime, Duration, Utc}; use errors::Result; use reqwest::Client as HttpClient; +use serde_derive::Deserialize; +use std::{path::Path, sync::Arc}; +use tokio::sync::RwLock; -use self::errors::ServiceAccountError; +pub use self::jwt::ServiceAccountKey; -mod errors; +pub mod errors; mod jwt; #[derive(Debug, Clone)] pub struct ServiceAccount { + http_client: HttpClient, + key: ServiceAccountKey, scopes: String, - key_path: String, user_email: Option, - - access_token: Option, - expires_at: Option, - - http_client: HttpClient, + access_token: Arc>>, } -#[derive(Debug, serde_derive::Deserialize)] -struct Token { - access_token: String, - expires_in: u64, - token_type: String, +#[derive(Debug, Clone)] +pub struct AccessToken { + pub bearer_token: String, + pub expires_at: DateTime, } -impl Token { - fn bearer_token(&self) -> String { - format!("{} {}", self.token_type, self.access_token) +impl ServiceAccount { + pub fn builder() -> ServiceAccountBuilder { + ServiceAccountBuilder::new() } -} -impl ServiceAccount { /// Creates a new service account from a key file and scopes - pub fn from_file(key_path: &str, scopes: Vec<&str>) -> Self { - Self { - scopes: scopes.join(" "), - key_path: key_path.to_string(), - user_email: None, - - access_token: None, - expires_at: None, - - http_client: HttpClient::new(), - } + pub fn from_file>(key_path: P, scopes: Vec<&str>) -> Result { + let bytes = std::fs::read(&key_path) + .map_err(|e| ServiceAccountError::ReadKey(key_path.as_ref().to_path_buf(), e))?; + let key = serde_json::from_slice::(&bytes) + .map_err(ServiceAccountError::SerdeJson)?; + Ok(Self::builder().key(key).scopes(scopes).build()) } /// Sets the user email @@ -56,60 +49,125 @@ impl ServiceAccount { /// Returns an access token /// If the access token is not expired, it will return the cached access token /// Otherwise, it will exchange the JWT token for an access token - pub async fn access_token(&mut self) -> Result { - match (self.access_token.as_ref(), self.expires_at) { - (Some(access_token), Some(expires_at)) - if expires_at > Utc::now().timestamp() as u64 => - { - Ok(access_token.to_string()) - } + pub async fn access_token(&self) -> Result { + let access_token = self.access_token.read().await.clone(); + match access_token { + Some(access_token) if access_token.expires_at > Utc::now() => Ok(access_token), _ => { - let jwt_token = self.jwt_token()?; - let token = match self.exchange_jwt_token_for_access_token(jwt_token).await { - Ok(token) => token, - Err(err) => return Err(err), - }; - - let expires_at = Utc::now().timestamp() as u64 + token.expires_in - 30; - - self.access_token = Some(token.bearer_token()); - self.expires_at = Some(expires_at); - - Ok(token.bearer_token()) + let new_token = self.get_fresh_access_token().await?; + *self.access_token.write().await = Some(new_token.clone()); + Ok(new_token) } } } - async fn exchange_jwt_token_for_access_token( - &mut self, - jwt_token: jwt::JwtToken, - ) -> Result { - let req_builder = self.http_client.post(jwt_token.token_uri()).form(&[ - ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt_token.to_string()?), - ]); - - let res = match req_builder.send().await { - Ok(resp) => resp, - Err(err) => return Err(ServiceAccountError::HttpReqwest(err)), + async fn get_fresh_access_token(&self) -> Result { + let jwt_token = { + let mut token = jwt::JwtToken::from_key(&self.key)?.scope(self.scopes.clone()); + if let Some(user_email) = &self.user_email { + token = token.sub(user_email.clone()); + }; + token }; - let token = match res.json::().await { - Ok(token) => token, - Err(err) => return Err(ServiceAccountError::HttpReqwest(err)), - }; + #[derive(Debug, Deserialize)] + pub struct TokenResponse { + token_type: String, + access_token: String, + expires_in: i64, + } + + let response = self + .http_client + .post(jwt_token.token_uri()) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &jwt_token.to_string()?), + ]) + .send() + .await + .map_err(ServiceAccountError::HttpRequest)?; + + if !response.status().is_success() { + return Err(ServiceAccountError::HttpRequestUnsuccessful( + response.status(), + response.text().await, + )); + } + + let json = response + .json::() + .await + .map_err(ServiceAccountError::HttpJson)?; + + if json.token_type != "Bearer" { + return Err(ServiceAccountError::AccessTokenNotBeaarer(json.token_type)); + } + + // Account for clock skew or time to receive or process the response + const LEEWAY: Duration = Duration::seconds(30); - Ok(token) + let expires_at = Utc::now() + Duration::seconds(json.expires_in) - LEEWAY; + + Ok(AccessToken { + bearer_token: json.access_token, + expires_at, + }) } +} - fn jwt_token(&self) -> Result { - let token = jwt::JwtToken::from_file(&self.key_path)?; +pub struct ServiceAccountBuilder { + http_client: Option, + key: Option, + scopes: Option, + user_email: Option, +} + +impl ServiceAccountBuilder { + pub fn new() -> Self { + Self { + http_client: None, + key: None, + scopes: None, + user_email: None, + } + } - Ok(match self.user_email { - Some(ref user_email) => token.sub(user_email.to_string()), - None => token, + /// Panics if key is not provided + pub fn build(self) -> ServiceAccount { + ServiceAccount { + http_client: self.http_client.unwrap_or_default(), + key: self.key.expect("Key required"), + scopes: self.scopes.unwrap_or_default(), + user_email: self.user_email, + access_token: Arc::new(RwLock::new(None)), } - .scope(self.scopes.clone())) + } + + pub fn http_client(mut self, http_client: HttpClient) -> Self { + self.http_client = Some(http_client); + self + } + + pub fn key(mut self, key: ServiceAccountKey) -> Self { + self.key = Some(key); + self + } + + pub fn scopes(mut self, scopes: Vec<&str>) -> Self { + self.scopes = Some(scopes.join(" ")); + self + } + + pub fn user_email>(mut self, user_email: S) -> Self { + self.user_email = Some(user_email.into()); + self + } +} + +impl Default for ServiceAccountBuilder { + fn default() -> Self { + Self::new() } } @@ -121,22 +179,26 @@ mod tests { async fn test_access_token() { let scopes = vec!["https://www.googleapis.com/auth/drive"]; let key_path = "test_fixtures/service-account-key.json"; - let mut service_account = ServiceAccount::from_file(key_path, scopes); + let service_account = ServiceAccount::from_file(key_path, scopes).unwrap(); // TODO: fix this test - make sure we can run an integration test // let access_token = service_account.access_token(); // assert!(access_token.is_ok()); // assert!(!access_token.unwrap().is_empty()); - service_account.access_token = Some("test_access_token".to_string()); - - let expires_at = Utc::now().timestamp() as u64 + 3600; - service_account.expires_at = Some(expires_at); + let expires_at = Utc::now() + Duration::seconds(3600); + *service_account.access_token.write().await = Some(AccessToken { + bearer_token: "test_access_token".to_string(), + expires_at, + }); assert_eq!( - service_account.access_token().await.unwrap(), + service_account.access_token().await.unwrap().bearer_token, "test_access_token" ); - assert_eq!(service_account.expires_at.unwrap(), expires_at); + assert_eq!( + service_account.access_token().await.unwrap().expires_at, + expires_at + ); } }