diff --git a/Cargo.toml b/Cargo.toml index 338d970..f1b7590 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["kms"] +kms = ["dep:aws-config", "dep:aws-sdk-kms"] + [dependencies] +aws-config = { version = "0.55.3", optional = true } +aws-sdk-kms = { version = "0.28.0", optional = true } +base64 = "0.21.2" +futures = "0.3.28" # Tokio Dependencies tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread"] } axum = "0.6.20" diff --git a/config/development.toml b/config/development.toml index ac0521e..5d1539d 100644 --- a/config/development.toml +++ b/config/development.toml @@ -8,3 +8,7 @@ url = "postgres://sam:damn@localhost/locker" [secrets] tenant = "hyperswitch" master_key = "feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308" + +[kms] +region = "us-west-2" +key_id = "abc" diff --git a/src/app.rs b/src/app.rs index 46154e9..e397039 100644 --- a/src/app.rs +++ b/src/app.rs @@ -5,6 +5,14 @@ use hyper::server::conn; use crate::{config, error, routes, storage}; +#[cfg(feature = "kms")] +use crate::crypto::{ + kms::{self, Base64Encoded, KmsData, Raw}, + Encryption, +}; +#[cfg(feature = "kms")] +use std::marker::PhantomData; + /// /// AppState: /// @@ -43,6 +51,25 @@ where impl AppState { async fn new(config: config::Config) -> error_stack::Result { + #[cfg(feature = "kms")] + { + let master_key_kms_input: KmsData = KmsData { + data: String::from_utf8(config.secrets.master_key.clone()) + .expect("Failed while converting bytes to String"), + decode_op: PhantomData, + }; + + #[allow(clippy::expect_used)] + let kms_decrypted_master_key: KmsData = kms::get_kms_client(&config.kms) + .await + .decrypt(master_key_kms_input) + .await + .expect("Failed while performing KMS decryption"); + + let mut config = config.clone(); + config.secrets.master_key = kms_decrypted_master_key.data; + } + Ok(Self { db: storage::Storage::new(config.database.url.to_owned()) .await diff --git a/src/config.rs b/src/config.rs index 54c8e95..70ad3a5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "kms")] +use crate::crypto::kms; + use std::path::PathBuf; #[derive(Clone, serde::Deserialize)] @@ -5,6 +8,8 @@ pub struct Config { pub server: Server, pub database: Database, pub secrets: Secrets, + #[cfg(feature = "kms")] + pub kms: kms::KmsConfig, } #[derive(Clone, serde::Deserialize)] @@ -25,19 +30,13 @@ pub struct Secrets { pub master_key: Vec, } -/// Function to deserialize hex -> Vec this is used in case of non KMS decryption fn deserialize_hex<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, { - let hex_str: String = serde::Deserialize::deserialize(deserializer)?; - - let bytes = match hex::decode(hex_str) { - Ok(data) => data, - Err(_) => return Err(serde::de::Error::custom("Base64 decoding error")), - }; + let deserialized_str: String = serde::Deserialize::deserialize(deserializer)?; - Ok(bytes) + Ok(deserialized_str.into_bytes()) } /// Get the origin directory of the project diff --git a/src/crypto.rs b/src/crypto.rs index d353d0e..1ef9a16 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -4,10 +4,21 @@ /// A trait to be used internally for maintaining and managing encryption algorithms /// pub trait Encryption { - type ReturnType; - fn encrypt(&self, input: I) -> Self::ReturnType; - fn decrypt(&self, input: O) -> Self::ReturnType; + type ReturnType<'b, T> + where + Self: 'b; + fn encrypt(&self, input: I) -> Self::ReturnType<'_, O>; + fn decrypt(&self, input: O) -> Self::ReturnType<'_, I>; } pub mod aes; pub mod jw; +#[cfg(feature = "kms")] +pub mod kms; + +#[cfg(feature = "kms")] +pub mod consts { + /// General purpose base64 engine + pub(crate) const BASE64_ENGINE: base64::engine::GeneralPurpose = + base64::engine::general_purpose::STANDARD; +} diff --git a/src/crypto/aes.rs b/src/crypto/aes.rs index bfa639e..57beeed 100644 --- a/src/crypto/aes.rs +++ b/src/crypto/aes.rs @@ -71,8 +71,8 @@ impl ring::aead::NonceSequence for NonceSequence { } impl super::Encryption, Vec> for GcmAes256 { - type ReturnType = error_stack::Result; - fn encrypt(&self, mut input: Vec) -> Self::ReturnType> { + type ReturnType<'b, T> = error_stack::Result; + fn encrypt(&self, mut input: Vec) -> Self::ReturnType<'_, Vec> { let nonce_sequence = NonceSequence::new().change_context(error::CryptoError::EncryptionError)?; let current_nonce = nonce_sequence.current(); @@ -87,7 +87,7 @@ impl super::Encryption, Vec> for GcmAes256 { Ok(input) } - fn decrypt(&self, input: Vec) -> Self::ReturnType> { + fn decrypt(&self, input: Vec) -> Self::ReturnType<'_, Vec> { let key = aead::UnboundKey::new(&aead::AES_256_GCM, &self.secret) .change_context(error::CryptoError::DecryptionError)?; diff --git a/src/crypto/jw.rs b/src/crypto/jw.rs index 3b8dc1f..51f7f54 100644 --- a/src/crypto/jw.rs +++ b/src/crypto/jw.rs @@ -72,9 +72,9 @@ impl JweBody { } impl super::Encryption, Vec> for JWEncryption { - type ReturnType = Result; + type ReturnType<'a, T> = Result; - fn encrypt(&self, input: Vec) -> Self::ReturnType> { + fn encrypt(&self, input: Vec) -> Self::ReturnType<'_, Vec> { let payload = input; let jws_encoded = jws_sign_payload(&payload, self.private_key.as_bytes())?; let jws_body = JwsBody::from_str(&jws_encoded).ok_or(error::CryptoError::InvalidData( @@ -87,7 +87,7 @@ impl super::Encryption, Vec> for JWEncryption { Ok(serde_json::to_vec(&jwe_body)?) } - fn decrypt(&self, input: Vec) -> Self::ReturnType> { + fn decrypt(&self, input: Vec) -> Self::ReturnType<'_, Vec> { let jwe_body: JweBody = serde_json::from_slice(&input)?; let jwe_encoded = jwe_body.get_dotted_jwe(); let algo = jwe::RSA_OAEP; diff --git a/src/crypto/kms.rs b/src/crypto/kms.rs new file mode 100644 index 0000000..19dbfc8 --- /dev/null +++ b/src/crypto/kms.rs @@ -0,0 +1,225 @@ +use std::{marker::PhantomData, pin::Pin}; + +use aws_config::meta::region::RegionProviderChain; +use aws_sdk_kms::{config::Region, primitives::Blob, Client}; +use base64::Engine; +use error_stack::{report, ResultExt}; +use futures::Future; + +use crate::crypto::Encryption; + +use super::consts; + +static KMS_CLIENT: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); + +/// Returns a shared KMS client, or initializes a new one if not previously initialized. +#[inline] +pub async fn get_kms_client(config: &KmsConfig) -> &'static KmsClient { + KMS_CLIENT.get_or_init(|| KmsClient::new(config)).await +} + +/// Configuration parameters required for constructing a [`KmsClient`]. +#[derive(Clone, Debug, Default, serde::Deserialize)] +#[serde(default)] +pub struct KmsConfig { + /// The AWS key identifier of the KMS key used to encrypt or decrypt data. + pub key_id: String, + + /// The AWS region to send KMS requests to. + pub region: String, +} + +/// Client for KMS operations. +#[derive(Debug)] +pub struct KmsClient { + inner_client: Client, + key_id: String, +} + +impl KmsClient { + /// Constructs a new KMS client. + pub async fn new(config: &KmsConfig) -> Self { + let region_provider = RegionProviderChain::first_try(Region::new(config.region.clone())); + let sdk_config = aws_config::from_env().region(region_provider).load().await; + + Self { + inner_client: Client::new(&sdk_config), + key_id: config.key_id.clone(), + } + } +} + +/// Errors that could occur during KMS operations. +#[derive(Debug, thiserror::Error)] +pub enum KmsError { + /// An error occurred when base64 decoding input data. + #[error("Failed to base64 decode input data")] + Base64DecodingFailed, + + /// An error occurred when hex decoding input data. + #[error("Failed to hex decode input data")] + HexDecodingFailed, + + /// An error occurred when KMS decrypting input data. + #[error("Failed to KMS decrypt input data")] + DecryptionFailed, + + /// The KMS decrypted output does not include a plaintext output. + #[error("Missing plaintext KMS decryption output")] + MissingPlaintextDecryptionOutput, + + /// An error occurred UTF-8 decoding KMS decrypted output. + #[error("Failed to UTF-8 decode decryption output")] + Utf8DecodingFailed, + + /// The KMS client has not been initialized. + #[error("The KMS client has not been initialized")] + KmsClientNotInitialized, +} + +impl KmsConfig { + /// Verifies that the [`KmsClient`] configuration is usable. + pub fn validate(&self) -> Result<(), &'static str> { + if self.key_id.trim().is_empty() { + return Err("KMS AWS key ID must not be empty"); + }; + + if self.region.trim().is_empty() { + return Err("KMS AWS region must not be empty"); + } + + Ok(()) + } +} + +#[derive(Clone, Debug, Default, serde::Deserialize, Eq, PartialEq)] + +pub struct KmsData { + pub data: T::Data, + pub decode_op: PhantomData, +} + +impl KmsData { + pub fn into_decoded(self) -> Result, T::Error> { + T::decode(self.data) + } + pub fn encode(data: Vec) -> Result { + Ok(Self { + data: T::encode(data)?, + decode_op: PhantomData, + }) + } +} + +pub trait Decoder { + type Data; + type Error; + fn encode(input: Vec) -> Result; + fn decode(input: Self::Data) -> Result, Self::Error>; +} + +pub struct StringEncoded; + +impl Decoder for StringEncoded { + type Data = String; + type Error = error_stack::Report; + + fn encode(input: Vec) -> Result { + String::from_utf8(input).change_context(KmsError::Utf8DecodingFailed) + } + fn decode(input: Self::Data) -> Result, Self::Error> { + Ok(input.into_bytes()) + } +} + +pub struct Base64Encoded; + +impl Decoder for Base64Encoded { + type Data = String; + type Error = error_stack::Report; + + fn encode(input: Vec) -> Result { + Ok(consts::BASE64_ENGINE.encode(input)) + } + fn decode(input: Self::Data) -> Result, Self::Error> { + consts::BASE64_ENGINE + .decode(input) + .change_context(KmsError::Base64DecodingFailed) + } +} + +pub struct HexEncoded; + +impl Decoder for HexEncoded { + type Data = String; + type Error = error_stack::Report; + + fn encode(input: Vec) -> Result { + Ok(hex::encode(input)) + } + fn decode(input: Self::Data) -> Result, Self::Error> { + hex::decode(input).change_context(KmsError::HexDecodingFailed) + } +} + +impl>> + Encryption, KmsData> for KmsClient +{ + type ReturnType<'b, T> = Pin> + 'b>>; + + fn encrypt(&self, _input: KmsData) -> Self::ReturnType<'_, KmsData> { + todo!() + } + + fn decrypt<'a>( + &'a self, + input: KmsData, + ) -> Pin, KmsError>> + 'a>> { + Box::pin(async move { + let data = input.into_decoded()?; + let ciphertext_blob = Blob::new(data.clone()); + + let decrypt_output = self + .inner_client + .decrypt() + .key_id(&self.key_id) + .ciphertext_blob(ciphertext_blob) + .send() + .await + .map_err(|error| { + println!("Failed to KMS decrypt data: {error:?}"); + error + }) + .change_context(KmsError::DecryptionFailed)?; + + let output = decrypt_output + .plaintext + .ok_or(report!(KmsError::MissingPlaintextDecryptionOutput)) + .and_then(|blob| { + String::from_utf8(blob.into_inner()) + .change_context(KmsError::Utf8DecodingFailed) + })?; + let decoded_output = consts::BASE64_ENGINE + .decode(output) + .change_context(KmsError::Base64DecodingFailed)?; + + KmsData::encode(decoded_output) + }) + } +} + +pub struct Raw; + +impl Decoder for Raw { + type Data = Vec; + + type Error = error_stack::Report; + + fn encode(input: Vec) -> Result { + Ok(input) + } + + fn decode(input: Self::Data) -> Result, Self::Error> { + Ok(input) + } +} diff --git a/src/storage/types.rs b/src/storage/types.rs index 6428094..791d75f 100644 --- a/src/storage/types.rs +++ b/src/storage/types.rs @@ -155,7 +155,7 @@ pub(super) trait StorageDecryption: Sized { fn decrypt( self, algo: &Self::Algorithm, - ) -> , Vec>>::ReturnType; + ) -> , Vec>>::ReturnType<'_, Self::Output>; } pub(super) trait StorageEncryption: Sized { @@ -164,7 +164,7 @@ pub(super) trait StorageEncryption: Sized { fn encrypt( self, algo: &Self::Algorithm, - ) -> , Vec>>::ReturnType; + ) -> , Vec>>::ReturnType<'_, Self::Output>; } impl StorageDecryption for MerchantInner { @@ -175,7 +175,8 @@ impl StorageDecryption for MerchantInner { fn decrypt( self, algo: &Self::Algorithm, - ) -> , Vec>>::ReturnType { + ) -> , Vec>>::ReturnType<'_, Self::Output> + { Ok(Self::Output { merchant_id: self.merchant_id, enc_key: algo.decrypt(self.enc_key.into_inner().expose())?.into(), @@ -193,7 +194,8 @@ impl StorageEncryption for MerchantNew { fn encrypt( self, algo: &Self::Algorithm, - ) -> , Vec>>::ReturnType { + ) -> , Vec>>::ReturnType<'_, Self::Output> + { Ok(Self::Output { merchant_id: self.merchant_id, enc_key: algo.encrypt(self.enc_key.expose())?.into(), @@ -210,7 +212,8 @@ impl StorageDecryption for LockerInner { fn decrypt( self, algo: &Self::Algorithm, - ) -> , Vec>>::ReturnType { + ) -> , Vec>>::ReturnType<'_, Self::Output> + { Ok(Self::Output { locker_id: self.locker_id, tenant_id: self.tenant_id, @@ -230,7 +233,8 @@ impl StorageEncryption for LockerNew { fn encrypt( self, algo: &Self::Algorithm, - ) -> , Vec>>::ReturnType { + ) -> , Vec>>::ReturnType<'_, Self::Output> + { Ok(Self::Output { locker_id: self.locker_id, tenant_id: self.tenant_id,