diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..c21d493 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,13 @@ +name: Cargo checks +on: + push: + pull_request: +jobs: + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + - run: cargo clippy --workspace --all-features --all-targets -- -D warnings + - run: cargo test --workspace --all-features --all-targets + - run: cargo fmt -- --check diff --git a/Cargo.toml b/Cargo.toml index 02e74fc..0029c75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] description = "HTTP Client for Google OAuth2" name = "gauth" -version = "0.8.0" +version = "0.9.0" authors = ["Simon Makarski "] edition = "2021" license = "MIT OR Apache-2.0" @@ -17,21 +17,21 @@ serde = "1" serde_json = "1" serde_derive = "1" dirs = "5.0.1" -reqwest = { version = "0.11", features = ["json"] } +reqwest = { version = "0.12.4", features = ["json"] } chrono = "0.4.31" base64 = "0.21.3" ring = "0.16.20" thiserror = "1.0.48" anyhow = "1.0.40" futures = { version = "0.3", features = ["executor"], optional = true } -tokio = { version = "1.33.0", optional = true } +tokio = { version = "1.33.0", features = ["test-util", "sync"] } log = { version = "0.4", optional = true } [dev-dependencies] mockito = "1.2.0" -tokio = { version = "1.33.0", features = ["test-util"] } +tokio = { version = "1.33.0", features = ["test-util", "rt", "macros", "rt-multi-thread"] } env_logger = "0.10.0" [features] app-blocking = ["dep:futures"] -token-watcher = ["dep:tokio", "dep:async-trait", "dep:log"] +token-watcher = ["dep:async-trait", "dep:log"] diff --git a/README.md b/README.md index 70943b7..0f18c4f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ The library supports the following Google Auth flows: ```toml [dependencies] -gauth = "0.8" +gauth = "0.9" ``` #### OAuth2 @@ -45,7 +45,7 @@ It is also possible to make a **blocking call** to retrieve an access token. Thi ``` [dependencies] -gauth = { version = "0.8", features = ["app-blocking"] } +gauth = { version = "0.9", features = ["app-blocking"] } ``` ```rust,no_run @@ -123,7 +123,7 @@ To resolve this, we adopted an experimental approach by developing a `token_prov ``` [dependencies] -gauth = { version = "0.8", features = ["token-watcher"] } +gauth = { version = "0.9", features = ["token-watcher"] } ``` ```rust,no_run diff --git a/examples/async_token_provider.rs b/examples/async_token_provider.rs index f9cbbb1..46bbd5f 100644 --- a/examples/async_token_provider.rs +++ b/examples/async_token_provider.rs @@ -12,8 +12,11 @@ async fn main() -> Result<(), Box> { .nth(1) .expect("Provide a path to the service account key file"); - let service_account = - ServiceAccount::from_file(&keypath, vec!["https://www.googleapis.com/auth/pubsub"]); + let service_account = ServiceAccount::from_file(&keypath) + .unwrap() + .scopes(vec!["https://www.googleapis.com/auth/pubsub"]) + .build() + .unwrap(); let tp = AsyncTokenProvider::new(service_account).with_interval(5); diff --git a/src/app/mod.rs b/src/app/mod.rs index 8e2a80a..1e43db4 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -90,7 +90,7 @@ impl Auth { /// App_name can be used to override the default app name pub fn app_name(mut self, app_name: &str) -> Self { - self.app_name = app_name.to_owned(); + app_name.clone_into(&mut self.app_name); self } @@ -260,7 +260,7 @@ mod tests { #[tokio::test] async fn test_access_token_success() { - let mut google = mockito::Server::new(); + let mut google = mockito::Server::new_async().await; let google_host = google.url(); google diff --git a/src/serv_account/errors.rs b/src/serv_account/errors.rs index cd0ea9a..fd31bb9 100644 --- a/src/serv_account/errors.rs +++ b/src/serv_account/errors.rs @@ -1,25 +1,57 @@ -use std::result::Result as StdResult; +use reqwest::StatusCode; +use ring::error::{KeyRejected, Unspecified}; +use std::{io, path::PathBuf}; use thiserror::Error; #[derive(Debug, Error)] -pub enum ServiceAccountError { - #[error("failed to read key file: {0}")] - ReadKey(String), +pub enum ServiceAccountFromFileError { + #[error("failed to read key file: {0}: {1}")] + ReadFile(PathBuf, io::Error), #[error("failed to de/serialize to json")] - SerdeJson(#[from] serde_json::Error), + DeserializeFile(#[from] serde_json::Error), - #[error("failed to decode base64")] - Base64Decode(#[from] base64::DecodeError), + #[error("Failed to initialize service account: {0}")] + ServiceAccountInitialization(ServiceAccountBuildError), - #[error("failed to create rsa key pair: {0}")] - RsaKeyPair(String), + #[error("Failed to get access token: {0}")] + GetAccessToken(GetAccessTokenError), +} + +#[derive(Debug, Error)] +pub enum ServiceAccountBuildError { + #[error("RSA private key didn't start with PEM prefix: -----BEGIN PRIVATE KEY-----")] + RsaPrivateKeyNoPrefix, - #[error("failed to rsa sign: {0}")] - RsaSign(String), + #[error("RSA private key didn't end with PEM suffix: -----END PRIVATE KEY-----")] + RsaPrivateKeyNoSuffix, - #[error("failed to send request")] - HttpReqwest(#[from] reqwest::Error), + #[error("RSA private key could not be decoded as base64: {0}")] + RsaPrivateKeyDecode(base64::DecodeError), + + #[error("RSA private key could not be parsed: {0}")] + RsaPrivateKeyParse(KeyRejected), } -pub type Result = StdResult; +#[derive(Debug, Error)] +pub enum GetAccessTokenError { + #[error("failed to serialize JSON: {0}")] + JsonSerialization(serde_json::Error), + + #[error("failed to RSA sign: {0}")] + RsaSign(Unspecified), + + #[error("failed to send request")] + HttpRequest(reqwest::Error), + + #[error("failed to send request")] + HttpRequestUnsuccessful(StatusCode, std::result::Result), + + #[error("failed to get response JSON")] + HttpJson(reqwest::Error), + + #[error("response returned non-Bearer auth access token: {0}")] + AccessTokenNotBearer(String), + + // TODO error variant for invalid authentication +} diff --git a/src/serv_account/jwt.rs b/src/serv_account/jwt.rs index d61afb6..741d426 100644 --- a/src/serv_account/jwt.rs +++ b/src/serv_account/jwt.rs @@ -1,194 +1,209 @@ -use super::errors::{Result, ServiceAccountError}; +use super::errors::{GetAccessTokenError, ServiceAccountBuildError}; +use base64::{engine::general_purpose, Engine as _}; +use ring::{ + rand, + signature::{self, RsaKeyPair}, +}; +use serde_derive::Deserialize; use serde_derive::Serialize; +use std::sync::Arc; -#[derive(Clone, Debug, Default, Serialize)] -pub struct JwtToken { - private_key: String, - header: JwtHeader, - payload: JwtPayload, -} - -#[derive(Clone, Debug, Default, Serialize)] -struct JwtHeader { - alg: String, - typ: String, -} - -#[derive(Clone, Debug, Default, Serialize)] -struct JwtPayload { +#[derive(Debug, Clone)] +pub struct JwtTokenSigner { + key_pair: Arc, + rng: rand::SystemRandom, iss: String, sub: Option, scope: String, aud: String, - exp: u64, - iat: u64, } -use base64::{engine::general_purpose, Engine as _}; -use ring::{rand, signature}; -use serde_derive::Deserialize; - -impl JwtToken { - /// Creates a new JWT token from a service account key file - pub fn from_file(key_path: &str) -> Result { - let private_key_content = std::fs::read(key_path) - .map_err(|err| ServiceAccountError::ReadKey(format!("{}: {}", err, key_path)))?; - - let key_data = serde_json::from_slice::(&private_key_content)?; - - let iat = chrono::Utc::now().timestamp() as u64; - let exp = iat + 3600; - - let private_key = key_data - .private_key - .replace('\n', "") - .replace("-----BEGIN PRIVATE KEY-----", "") - .replace("-----END PRIVATE KEY-----", ""); +impl JwtTokenSigner { + /// Creates a new JWT token from a service account key + pub fn from_key( + key: ServiceAccountKey, + scope: String, + sub: Option, + ) -> Result { + let no_whitespace = key.private_key.replace('\n', ""); + let private_key = no_whitespace + .strip_prefix("-----BEGIN PRIVATE KEY-----") + .ok_or(ServiceAccountBuildError::RsaPrivateKeyNoPrefix)? + .strip_suffix("-----END PRIVATE KEY-----") + .ok_or(ServiceAccountBuildError::RsaPrivateKeyNoSuffix)?; + println!("private_key: {:?}", private_key); + + let decoded = general_purpose::STANDARD + .decode(private_key.as_bytes()) + .map_err(ServiceAccountBuildError::RsaPrivateKeyDecode)?; + let key_pair = RsaKeyPair::from_pkcs8(&decoded) + .map_err(ServiceAccountBuildError::RsaPrivateKeyParse)?; Ok(Self { - header: JwtHeader { - alg: String::from("RS256"), - typ: String::from("JWT"), - }, - payload: JwtPayload { - iss: key_data.client_email, - sub: None, - scope: String::new(), - aud: key_data.token_uri, - exp, - iat, - }, - private_key, + iss: key.client_email, + rng: rand::SystemRandom::new(), + sub, + scope, + aud: key.token_uri, + key_pair: Arc::new(key_pair), }) } - /// Returns a JWT token string - pub fn to_string(&self) -> Result { - let header = serde_json::to_vec(&self.header)?; - let payload = serde_json::to_vec(&self.payload)?; - - let base64_header = general_purpose::STANDARD.encode(header); - let base64_payload = general_purpose::STANDARD.encode(payload); - - let raw_signature = format!("{}.{}", base64_header, base64_payload); - let signature = self.sign_rsa(raw_signature)?; + /// Returns a signed JWT token string + pub fn sign(&self) -> Result { + #[derive(Clone, Debug, Default, Serialize)] + struct JwtHeader<'a> { + alg: &'a str, + typ: &'a str, + } + let header = serde_json::to_vec(&JwtHeader { + alg: "RS256", + typ: "JWT", + }) + .map_err(GetAccessTokenError::JsonSerialization)?; + let header = general_purpose::STANDARD.encode(header); + + #[derive(Clone, Debug, Default, Serialize)] + struct JwtPayload<'a> { + iss: &'a str, + sub: Option<&'a str>, + scope: &'a str, + aud: &'a str, + exp: u64, + iat: u64, + } + let iat = chrono::Utc::now().timestamp() as u64; + let exp = iat + 3600; + let payload = serde_json::to_vec(&JwtPayload { + iss: &self.iss, + sub: self.sub.as_deref(), + scope: &self.scope, + aud: &self.aud, + exp, + iat, + }) + .map_err(GetAccessTokenError::JsonSerialization)?; + let payload = general_purpose::STANDARD.encode(payload); - let base64_signature = general_purpose::STANDARD.encode(signature); + let to_sign = format!("{header}.{payload}"); + let signature = + sign_rsa(&self.key_pair, &self.rng, &to_sign).map_err(GetAccessTokenError::RsaSign)?; + let signature = general_purpose::STANDARD.encode(signature); - Ok(format!( - "{}.{}.{}", - base64_header, base64_payload, base64_signature - )) + Ok(format!("{to_sign}.{signature}")) } /// Returns the token uri pub fn token_uri(&self) -> &str { - &self.payload.aud + &self.aud } +} - /// Sets the sub field in the payload - pub fn sub(mut self, sub: String) -> Self { - self.payload.sub = Some(sub); - self - } - - /// Sets the scope field in the payload - pub fn scope(mut self, scope: String) -> Self { - self.payload.scope = scope; - self - } - - /// Signs a message with the private key - fn sign_rsa(&self, message: String) -> Result> { - let private_key = self.private_key.as_bytes(); - let decoded = general_purpose::STANDARD.decode(private_key)?; - - let key_pair = signature::RsaKeyPair::from_pkcs8(&decoded).map_err(|err| { - ServiceAccountError::RsaKeyPair(format!("failed tp create key pair: {}", err)) - })?; - - // Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm. - let rng = rand::SystemRandom::new(); - let mut signature = vec![0; key_pair.public_modulus_len()]; - key_pair - .sign( - &signature::RSA_PKCS1_SHA256, - &rng, - message.as_bytes(), - &mut signature, - ) - .map_err(|err| ServiceAccountError::RsaSign(format!("{}", err)))?; - - Ok(signature) - } +/// Signs a message with the private key +fn sign_rsa( + key_pair: &RsaKeyPair, + rng: &dyn rand::SecureRandom, + message: &str, +) -> Result, ring::error::Unspecified> { + // Sign the message, using PKCS#1 v1.5 padding and the SHA256 digest algorithm. + let mut signature = vec![0; key_pair.public_modulus_len()]; + key_pair.sign( + &signature::RSA_PKCS1_SHA256, + rng, + message.as_bytes(), + &mut signature, + )?; + + Ok(signature) } #[allow(dead_code)] -#[derive(Debug, Deserialize)] -struct ServiceAccountKey { - r#type: String, - project_id: String, - private_key_id: String, - private_key: String, - client_email: String, - client_id: String, - auth_uri: String, - token_uri: String, - auth_provider_x509_cert_url: String, - client_x509_cert_url: String, - universe_domain: String, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ServiceAccountKey { + pub r#type: String, + pub project_id: String, + pub private_key_id: String, + pub private_key: String, + pub client_email: String, + pub client_id: String, + pub auth_uri: String, + pub token_uri: String, + pub auth_provider_x509_cert_url: String, + pub client_x509_cert_url: String, + pub universe_domain: String, } #[cfg(test)] mod tests { use super::*; + use serde_json::Value; - const SERVICE_ACCOUNT_KEY_PATH: &str = "test_fixtures/service-account-key.json"; + fn read_key() -> ServiceAccountKey { + serde_json::from_slice(include_bytes!( + "../../test_fixtures/service-account-key.json" + )) + .unwrap() + } #[test] - fn test_jwt_token() { - let mut token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap(); - - assert_eq!(token.header.alg, "RS256"); - assert_eq!(token.header.typ, "JWT"); - assert!(token.payload.iss.contains("iam.gserviceaccount.com")); - assert_eq!(token.payload.sub, None); - assert_eq!(token.payload.scope, ""); - assert_eq!(token.payload.aud, "https://oauth2.googleapis.com/token"); - assert!(token.payload.exp > 0); - assert_eq!(token.payload.iat, token.payload.exp - 3600); - - token = token - .sub(String::from("some@email.domain")) - .scope(String::from("test_scope1 test_scope2 test_scope3")); - - assert_eq!(token.payload.sub, Some(String::from("some@email.domain"))); - assert_eq!(token.payload.scope, "test_scope1 test_scope2 test_scope3"); + fn test_rsa_sign() { + let key = "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCr/KzFiWfiw5vd8KrFPmsktUfmba4x8r0uPDxxdeI/zrENHPkef3Zd3Tt4bvdG4VRWAQ/zuomHcksTW1AYaaS/TfoiH5c/xivWptKHGS/eh91SgPunmoK9wbvdNW8C4goVdw57JUz6IG1vZpenHjI7ofHMfg+2cBiTsTSWFDnd1EoNkK2lmdP1R5lzxNSRce9HgugKvHAcvDtB2goL9coo8y+3kyBTiS5qCgpWplGwIMBACGW6U4a//GajvmvvZyfym7OXJeqjXznjNH32ghhjcP2DUuGf36wika1rOpmZKCJDKBoMPQERUDa1ydYLfY3v1g/8xFTL4ezuyYEkGuu5AgMBAAECggEAP3Meglno+53SuRR6y/31JTvD5Nz98Otuo8oROoKVD5k/dGkF9xxrHMHrmMjHbVzf8kK+Edr1tgSScfe0Gu2OnA02hLRG5n5D2hL9hF3kbSKOokt3jCPSrBL3Leryo4uk0Lp1mzTtqzGfbgPZWwwm2B0syZaQUWwVhRdRITUhDBcUW8cuxGXzNeDTJMUjij0li61H62rJFjE5nyxCpwlukqR96uVWN6wXhM4xhzwhaHt6oGVUAENG3Er+ZjYCgBISQkEuiaFUgB3Zkv3qYWhaWNhwhO6MDsT33xex4Ecw4epCrAfEirkP1AIYmVWFw3uxODOJ/u8mb6IQIobnxwRiIQKBgQDihX+XxV8tSvHxgHTN5vzp4oOgnKhmiClm7/MSbjwHjLcffWh6gqBLbPAvcrfA0aewIT29xgIO0CpygJcg/4RND30YKTilYo7/ieTkdwRYsCbt9zM/WBop1snZja4Zox/SK23u4OJ4uUw0e4onXOOzAogCtiEKMx+U6+JmsyhNFQKBgQDCXmAhdrinbfXtsC5J+HwC81XaFujE2l4EiLqVaHH6DIrVTNSucf6O/nsCHWhttb3U7xT7CIHCe1om8peKZsjuiQqmlKjeqPRhDNlLXV5TadIKUs8svPM+MUXArhTc3vAv1pArhi7RpQ5F1AeTJGkOvxcY6vmMjXIb/dSiZMp1FQKBgDIii+fidjtHEB98Z92+lxGI4cslgRwYXNl8mBbnMQAWw90DW6Fp0eJ/vPUzdboGbQ/Ne6XJ8mCm8A4hqdFS3ExV9kDntrLcCnxCX9e1A9BBRIx8nuoRLNE/ybMN6Y+hDATvOciaG2XO1S/0e9JUe8z97W50MwHX6NCEGLrUQkI1AoGADD4lj/YKa4FhnDccs0wTg5wQLEyFHOEkSuTR29dYVoeztvu/6b0Ea71bwiZYDZEFBASLLcS7Z6SdaRaetPkEbwHyyctTV7MMsZA9n6Gh718a+8t7gTXlnGU+H4TXi5H/TwQU0KkDCfF7lKpmT75bX7Jpoggq7895AIpcel4e4oECgYAbddARaP5mH2KAiSoBUlvh4P2beCv5HmWjIhS2nA7KaGOtGfOk9/VGTRLZXtPed70cGD5SrgMze3umI37nAtcVv+MHcZSXhjoSQZ6M3GChaDUwJNC+f6GVjfadn7LOsY5L1+0cu1pe6r4uXBOwmvv1tynpY6sGOE+tPJibK5Pm8Q=="; + let key_pair = RsaKeyPair::from_pkcs8(&general_purpose::STANDARD.decode(key).unwrap()) + .expect("Failed to parse key"); + let rng = rand::SystemRandom::new(); + let message = "hello world"; + let signature = sign_rsa(&key_pair, &rng, message).unwrap(); + assert_eq!(signature.len(), 256); } #[test] - fn test_sign_rsa() { - let message = String::from("hello, world"); - - let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH).unwrap(); - let signature = token.sign_rsa(message).unwrap(); + fn test_sign() { + let scope = "test_scope1 test_scope2 test_scope3"; + let signer = JwtTokenSigner::from_key(read_key(), scope.to_owned(), None).unwrap(); + let token = signer.sign().unwrap(); + println!("token: {:?}", token); + let parts = token.split('.').collect::>(); + assert_eq!(parts.len(), 3); + let mut parts = parts.into_iter(); + + let header = parts.next().unwrap(); + let header = general_purpose::STANDARD.decode(header).unwrap(); + let header = serde_json::from_slice::(&header).unwrap(); + assert_eq!(header["alg"], "RS256"); + assert_eq!(header["typ"], "JWT"); + + let payload = parts.next().unwrap(); + let payload = general_purpose::STANDARD.decode(payload).unwrap(); + let payload = serde_json::from_slice::(&payload).unwrap(); + assert_eq!(payload["scope"], Value::String(scope.to_owned())); + assert_eq!(payload["sub"], Value::Null); + assert_eq!(payload["aud"], "https://oauth2.googleapis.com/token"); + assert!(payload["exp"].as_i64().unwrap() > 0); + assert_eq!( + payload["iat"].as_i64().unwrap(), + payload["exp"].as_i64().unwrap() - 3600 + ); + let signature = parts.next().unwrap(); + let signature = general_purpose::STANDARD.decode(signature).unwrap(); assert_eq!(signature.len(), 256); } #[test] - fn test_token_to_string() { - let token = JwtToken::from_file(SERVICE_ACCOUNT_KEY_PATH) - .unwrap() - .sub(String::from("some@email.com")) - .scope(String::from("https://www.googleapis.com/auth/pubsub")); - - let token_string = token.to_string(); - - assert!(token_string.is_ok(), "token string successfully created"); - assert!( - !token_string.unwrap().is_empty(), - "token string is not empty" - ); + fn test_sign_email() { + let sub = "some@email.domain"; + let signer = + JwtTokenSigner::from_key(read_key(), "".to_owned(), Some(sub.to_owned())).unwrap(); + let token = signer.sign().unwrap(); + let parts = token.split('.').collect::>(); + assert_eq!(parts.len(), 3); + let mut parts = parts.into_iter(); + + let _header = parts.next().unwrap(); + + let payload = parts.next().unwrap(); + let payload = general_purpose::STANDARD.decode(payload).unwrap(); + let payload = serde_json::from_slice::(&payload).unwrap(); + assert_eq!(payload["sub"], Value::String(sub.to_owned())); } } diff --git a/src/serv_account/mod.rs b/src/serv_account/mod.rs index 2863b12..63053c3 100644 --- a/src/serv_account/mod.rs +++ b/src/serv_account/mod.rs @@ -1,115 +1,166 @@ -use chrono::Utc; -use errors::Result; +use self::{ + errors::{ + GetAccessTokenError, ServiceAccountBuildError as ServiceAccountBuilderError, + ServiceAccountFromFileError, + }, + jwt::JwtTokenSigner, +}; +use chrono::{DateTime, Duration, Utc}; use reqwest::Client as HttpClient; +use serde_derive::Deserialize; +use std::{path::Path, sync::Arc}; +use tokio::sync::RwLock; -use self::errors::ServiceAccountError; +pub use self::jwt::ServiceAccountKey; -pub(crate) mod errors; +pub mod errors; mod jwt; #[derive(Debug, Clone)] pub struct ServiceAccount { - scopes: String, - key_path: String, - user_email: Option, - - access_token: Option, - expires_at: Option, - http_client: HttpClient, + jwt_token: JwtTokenSigner, + access_token: Arc>>, } -#[derive(Debug, serde_derive::Deserialize)] -struct Token { - access_token: String, - expires_in: u64, - token_type: String, -} - -impl Token { - fn bearer_token(&self) -> String { - format!("{} {}", self.token_type, self.access_token) - } +#[derive(Debug, Clone)] +pub struct AccessToken { + pub bearer_token: String, + pub expires_at: DateTime, } impl ServiceAccount { - /// Creates a new service account from a key file and scopes - pub fn from_file(key_path: &str, scopes: Vec<&str>) -> Self { - Self { - scopes: scopes.join(" "), - key_path: key_path.to_string(), - user_email: None, - - access_token: None, - expires_at: None, - - http_client: HttpClient::new(), - } + pub fn builder() -> ServiceAccountBuilder { + ServiceAccountBuilder::new() } - /// Sets the user email - pub fn user_email(mut self, user_email: &str) -> Self { - self.user_email = Some(user_email.to_string()); - self + /// Creates a new `ServiceAccountBuilder` from a key file + pub fn from_file>( + key_path: P, + ) -> Result { + let bytes = std::fs::read(&key_path).map_err(|e| { + ServiceAccountFromFileError::ReadFile(key_path.as_ref().to_path_buf(), e) + })?; + let key = serde_json::from_slice::(&bytes) + .map_err(ServiceAccountFromFileError::DeserializeFile)?; + Ok(Self::builder().key(key)) } /// Returns an access token /// If the access token is not expired, it will return the cached access token /// Otherwise, it will exchange the JWT token for an access token - pub async fn access_token(&mut self) -> Result { - match (self.access_token.as_ref(), self.expires_at) { - (Some(access_token), Some(expires_at)) - if expires_at > Utc::now().timestamp() as u64 => - { - Ok(access_token.to_string()) - } + pub async fn access_token(&self) -> Result { + let access_token = self.access_token.read().await.clone(); + match access_token { + Some(access_token) if access_token.expires_at > Utc::now() => Ok(access_token), _ => { - let jwt_token = self.jwt_token()?; - let token = match self.exchange_jwt_token_for_access_token(jwt_token).await { - Ok(token) => token, - Err(err) => return Err(err), - }; + let new_token = self.get_fresh_access_token().await?; + *self.access_token.write().await = Some(new_token.clone()); + Ok(new_token) + } + } + } - let expires_at = Utc::now().timestamp() as u64 + token.expires_in - 30; + async fn get_fresh_access_token(&self) -> Result { + #[derive(Debug, Deserialize)] + pub struct TokenResponse { + token_type: String, + access_token: String, + expires_in: i64, + } - self.access_token = Some(token.bearer_token()); - self.expires_at = Some(expires_at); + let response = self + .http_client + .post(self.jwt_token.token_uri()) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &self.jwt_token.sign()?), + ]) + .send() + .await + .map_err(GetAccessTokenError::HttpRequest)?; + + if !response.status().is_success() { + return Err(GetAccessTokenError::HttpRequestUnsuccessful( + response.status(), + response.text().await, + )); + } - Ok(token.bearer_token()) - } + let json = response + .json::() + .await + .map_err(GetAccessTokenError::HttpJson)?; + + if json.token_type != "Bearer" { + return Err(GetAccessTokenError::AccessTokenNotBearer(json.token_type)); } - } - async fn exchange_jwt_token_for_access_token( - &mut self, - jwt_token: jwt::JwtToken, - ) -> Result { - let req_builder = self.http_client.post(jwt_token.token_uri()).form(&[ - ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt_token.to_string()?), - ]); - - let res = match req_builder.send().await { - Ok(resp) => resp, - Err(err) => return Err(ServiceAccountError::HttpReqwest(err)), - }; - - let token = match res.json::().await { - Ok(token) => token, - Err(err) => return Err(ServiceAccountError::HttpReqwest(err)), - }; - - Ok(token) + // Account for clock skew or time to receive or process the response + const LEEWAY: Duration = Duration::seconds(30); + + let expires_at = Utc::now() + Duration::seconds(json.expires_in) - LEEWAY; + + Ok(AccessToken { + bearer_token: json.access_token, + expires_at, + }) } +} - fn jwt_token(&self) -> Result { - let token = jwt::JwtToken::from_file(&self.key_path)?; +pub struct ServiceAccountBuilder { + http_client: Option, + key: Option, + scopes: Option, + user_email: Option, +} - Ok(match self.user_email { - Some(ref user_email) => token.sub(user_email.to_string()), - None => token, +impl ServiceAccountBuilder { + pub fn new() -> Self { + Self { + http_client: None, + key: None, + scopes: None, + user_email: None, } - .scope(self.scopes.clone())) + } + + /// Panics if key is not provided + pub fn build(self) -> Result { + let key = self.key.expect("Key required"); + let jwt_token = + jwt::JwtTokenSigner::from_key(key, self.scopes.unwrap_or_default(), self.user_email)?; + Ok(ServiceAccount { + http_client: self.http_client.unwrap_or_default(), + jwt_token, + access_token: Arc::new(RwLock::new(None)), + }) + } + + pub fn http_client(mut self, http_client: HttpClient) -> Self { + self.http_client = Some(http_client); + self + } + + pub fn key(mut self, key: ServiceAccountKey) -> Self { + self.key = Some(key); + self + } + + pub fn scopes(mut self, scopes: Vec<&str>) -> Self { + self.scopes = Some(scopes.join(" ")); + self + } + + pub fn user_email>(mut self, user_email: S) -> Self { + self.user_email = Some(user_email.into()); + self + } +} + +impl Default for ServiceAccountBuilder { + fn default() -> Self { + Self::new() } } @@ -118,25 +169,28 @@ mod tests { use super::*; #[tokio::test] - async fn test_access_token() { + async fn test_access_token_cache() { let scopes = vec!["https://www.googleapis.com/auth/drive"]; let key_path = "test_fixtures/service-account-key.json"; - let mut service_account = ServiceAccount::from_file(key_path, scopes); - - // TODO: fix this test - make sure we can run an integration test - // let access_token = service_account.access_token(); - // assert!(access_token.is_ok()); - // assert!(!access_token.unwrap().is_empty()); - - service_account.access_token = Some("test_access_token".to_string()); - - let expires_at = Utc::now().timestamp() as u64 + 3600; - service_account.expires_at = Some(expires_at); + let service_account = ServiceAccount::from_file(key_path) + .unwrap() + .scopes(scopes) + .build() + .unwrap(); + + let expires_at = Utc::now() + Duration::seconds(3600); + *service_account.access_token.write().await = Some(AccessToken { + bearer_token: "test_access_token".to_string(), + expires_at, + }); assert_eq!( - service_account.access_token().await.unwrap(), + service_account.access_token().await.unwrap().bearer_token, "test_access_token" ); - assert_eq!(service_account.expires_at.unwrap(), expires_at); + assert_eq!( + service_account.access_token().await.unwrap().expires_at, + expires_at + ); } } diff --git a/src/token_provider/errors.rs b/src/token_provider/errors.rs index cf1fab8..11cba56 100644 --- a/src/token_provider/errors.rs +++ b/src/token_provider/errors.rs @@ -3,7 +3,7 @@ use thiserror::Error; use tokio::sync::mpsc::error::SendError; use tokio::sync::TryLockError; -use crate::serv_account::errors::ServiceAccountError; +use crate::serv_account::errors::GetAccessTokenError; #[derive(Debug, Error)] pub enum TokenProviderError { @@ -11,7 +11,7 @@ pub enum TokenProviderError { AccessToken(#[from] TryLockError), #[error("service account error: {0}")] - ServiceAccountError(#[from] ServiceAccountError), + GetAccessTokenError(#[from] GetAccessTokenError), #[error("failed to send token: {0}")] SendError(#[from] SendError),