Skip to content

Commit

Permalink
feat(kms): integrate kms feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Chethan-rao committed Oct 26, 2023
1 parent 5d9ab51 commit ead558d
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 23 deletions.
8 changes: 8 additions & 0 deletions Cargo.toml
Expand Up @@ -5,7 +5,15 @@ edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["kms"]
kms = ["dep:aws-config", "dep:aws-sdk-kms"]

[dependencies]
aws-config = { version = "0.55.3", optional = true }
aws-sdk-kms = { version = "0.28.0", optional = true }
base64 = "0.21.2"
futures = "0.3.28"
# Tokio Dependencies
tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread"] }
axum = "0.6.20"
Expand Down
4 changes: 4 additions & 0 deletions config/development.toml
Expand Up @@ -8,3 +8,7 @@ url = "postgres://sam:damn@localhost/locker"
[secrets]
tenant = "hyperswitch"
master_key = "feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308"

[kms]
region = "us-west-2"
key_id = "abc"
27 changes: 27 additions & 0 deletions src/app.rs
Expand Up @@ -5,6 +5,14 @@ use hyper::server::conn;

use crate::{config, error, routes, storage};

#[cfg(feature = "kms")]
use crate::crypto::{
kms::{self, Base64Encoded, KmsData, Raw},
Encryption,
};
#[cfg(feature = "kms")]
use std::marker::PhantomData;

///
/// AppState:
///
Expand Down Expand Up @@ -43,6 +51,25 @@ where

impl AppState {
async fn new(config: config::Config) -> error_stack::Result<Self, error::ConfigurationError> {
#[cfg(feature = "kms")]
{
let master_key_kms_input: KmsData<Base64Encoded> = KmsData {
data: String::from_utf8(config.secrets.master_key.clone())
.expect("Failed while converting bytes to String"),
decode_op: PhantomData,
};

#[allow(clippy::expect_used)]
let kms_decrypted_master_key: KmsData<Raw> = kms::get_kms_client(&config.kms)
.await
.decrypt(master_key_kms_input)
.await
.expect("Failed while performing KMS decryption");

let mut config = config.clone();
config.secrets.master_key = kms_decrypted_master_key.data;
}

Ok(Self {
db: storage::Storage::new(config.database.url.to_owned())
.await
Expand Down
15 changes: 7 additions & 8 deletions src/config.rs
@@ -1,10 +1,15 @@
#[cfg(feature = "kms")]
use crate::crypto::kms;

use std::path::PathBuf;

#[derive(Clone, serde::Deserialize)]
pub struct Config {
pub server: Server,
pub database: Database,
pub secrets: Secrets,
#[cfg(feature = "kms")]
pub kms: kms::KmsConfig,
}

#[derive(Clone, serde::Deserialize)]
Expand All @@ -25,19 +30,13 @@ pub struct Secrets {
pub master_key: Vec<u8>,
}

/// Function to deserialize hex -> Vec<u8> this is used in case of non KMS decryption
fn deserialize_hex<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: serde::Deserializer<'de>,
{
let hex_str: String = serde::Deserialize::deserialize(deserializer)?;

let bytes = match hex::decode(hex_str) {
Ok(data) => data,
Err(_) => return Err(serde::de::Error::custom("Base64 decoding error")),
};
let deserialized_str: String = serde::Deserialize::deserialize(deserializer)?;

Ok(bytes)
Ok(deserialized_str.into_bytes())
}

/// Get the origin directory of the project
Expand Down
17 changes: 14 additions & 3 deletions src/crypto.rs
Expand Up @@ -4,10 +4,21 @@
/// A trait to be used internally for maintaining and managing encryption algorithms
///
pub trait Encryption<I, O> {
type ReturnType<T>;
fn encrypt(&self, input: I) -> Self::ReturnType<O>;
fn decrypt(&self, input: O) -> Self::ReturnType<I>;
type ReturnType<'b, T>
where
Self: 'b;
fn encrypt(&self, input: I) -> Self::ReturnType<'_, O>;
fn decrypt(&self, input: O) -> Self::ReturnType<'_, I>;
}

pub mod aes;
pub mod jw;
#[cfg(feature = "kms")]
pub mod kms;

#[cfg(feature = "kms")]
pub mod consts {
/// General purpose base64 engine
pub(crate) const BASE64_ENGINE: base64::engine::GeneralPurpose =
base64::engine::general_purpose::STANDARD;
}
6 changes: 3 additions & 3 deletions src/crypto/aes.rs
Expand Up @@ -71,8 +71,8 @@ impl ring::aead::NonceSequence for NonceSequence {
}

impl super::Encryption<Vec<u8>, Vec<u8>> for GcmAes256 {
type ReturnType<T> = error_stack::Result<T, error::CryptoError>;
fn encrypt(&self, mut input: Vec<u8>) -> Self::ReturnType<Vec<u8>> {
type ReturnType<'b, T> = error_stack::Result<T, error::CryptoError>;
fn encrypt(&self, mut input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
let nonce_sequence =
NonceSequence::new().change_context(error::CryptoError::EncryptionError)?;
let current_nonce = nonce_sequence.current();
Expand All @@ -87,7 +87,7 @@ impl super::Encryption<Vec<u8>, Vec<u8>> for GcmAes256 {
Ok(input)
}

fn decrypt(&self, input: Vec<u8>) -> Self::ReturnType<Vec<u8>> {
fn decrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
let key = aead::UnboundKey::new(&aead::AES_256_GCM, &self.secret)
.change_context(error::CryptoError::DecryptionError)?;

Expand Down
6 changes: 3 additions & 3 deletions src/crypto/jw.rs
Expand Up @@ -72,9 +72,9 @@ impl JweBody {
}

impl super::Encryption<Vec<u8>, Vec<u8>> for JWEncryption {
type ReturnType<T> = Result<T, error::CryptoError>;
type ReturnType<'a, T> = Result<T, error::CryptoError>;

fn encrypt(&self, input: Vec<u8>) -> Self::ReturnType<Vec<u8>> {
fn encrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
let payload = input;
let jws_encoded = jws_sign_payload(&payload, self.private_key.as_bytes())?;
let jws_body = JwsBody::from_str(&jws_encoded).ok_or(error::CryptoError::InvalidData(
Expand All @@ -87,7 +87,7 @@ impl super::Encryption<Vec<u8>, Vec<u8>> for JWEncryption {
Ok(serde_json::to_vec(&jwe_body)?)
}

fn decrypt(&self, input: Vec<u8>) -> Self::ReturnType<Vec<u8>> {
fn decrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
let jwe_body: JweBody = serde_json::from_slice(&input)?;
let jwe_encoded = jwe_body.get_dotted_jwe();
let algo = jwe::RSA_OAEP;
Expand Down
225 changes: 225 additions & 0 deletions src/crypto/kms.rs
@@ -0,0 +1,225 @@
use std::{marker::PhantomData, pin::Pin};

use aws_config::meta::region::RegionProviderChain;
use aws_sdk_kms::{config::Region, primitives::Blob, Client};
use base64::Engine;
use error_stack::{report, ResultExt};
use futures::Future;

use crate::crypto::Encryption;

use super::consts;

static KMS_CLIENT: tokio::sync::OnceCell<KmsClient> = tokio::sync::OnceCell::const_new();

/// Returns a shared KMS client, or initializes a new one if not previously initialized.
#[inline]
pub async fn get_kms_client(config: &KmsConfig) -> &'static KmsClient {
KMS_CLIENT.get_or_init(|| KmsClient::new(config)).await
}

/// Configuration parameters required for constructing a [`KmsClient`].
#[derive(Clone, Debug, Default, serde::Deserialize)]
#[serde(default)]
pub struct KmsConfig {
/// The AWS key identifier of the KMS key used to encrypt or decrypt data.
pub key_id: String,

/// The AWS region to send KMS requests to.
pub region: String,
}

/// Client for KMS operations.
#[derive(Debug)]
pub struct KmsClient {
inner_client: Client,
key_id: String,
}

impl KmsClient {
/// Constructs a new KMS client.
pub async fn new(config: &KmsConfig) -> Self {
let region_provider = RegionProviderChain::first_try(Region::new(config.region.clone()));
let sdk_config = aws_config::from_env().region(region_provider).load().await;

Self {
inner_client: Client::new(&sdk_config),
key_id: config.key_id.clone(),
}
}
}

/// Errors that could occur during KMS operations.
#[derive(Debug, thiserror::Error)]
pub enum KmsError {
/// An error occurred when base64 decoding input data.
#[error("Failed to base64 decode input data")]
Base64DecodingFailed,

/// An error occurred when hex decoding input data.
#[error("Failed to hex decode input data")]
HexDecodingFailed,

/// An error occurred when KMS decrypting input data.
#[error("Failed to KMS decrypt input data")]
DecryptionFailed,

/// The KMS decrypted output does not include a plaintext output.
#[error("Missing plaintext KMS decryption output")]
MissingPlaintextDecryptionOutput,

/// An error occurred UTF-8 decoding KMS decrypted output.
#[error("Failed to UTF-8 decode decryption output")]
Utf8DecodingFailed,

/// The KMS client has not been initialized.
#[error("The KMS client has not been initialized")]
KmsClientNotInitialized,
}

impl KmsConfig {
/// Verifies that the [`KmsClient`] configuration is usable.
pub fn validate(&self) -> Result<(), &'static str> {
if self.key_id.trim().is_empty() {
return Err("KMS AWS key ID must not be empty");
};

if self.region.trim().is_empty() {
return Err("KMS AWS region must not be empty");
}

Ok(())
}
}

#[derive(Clone, Debug, Default, serde::Deserialize, Eq, PartialEq)]

pub struct KmsData<T: Decoder> {
pub data: T::Data,
pub decode_op: PhantomData<T>,
}

impl<T: Decoder> KmsData<T> {
pub fn into_decoded(self) -> Result<Vec<u8>, T::Error> {
T::decode(self.data)
}
pub fn encode(data: Vec<u8>) -> Result<Self, T::Error> {
Ok(Self {
data: T::encode(data)?,
decode_op: PhantomData,
})
}
}

pub trait Decoder {
type Data;
type Error;
fn encode(input: Vec<u8>) -> Result<Self::Data, Self::Error>;
fn decode(input: Self::Data) -> Result<Vec<u8>, Self::Error>;
}

pub struct StringEncoded;

impl Decoder for StringEncoded {
type Data = String;
type Error = error_stack::Report<KmsError>;

fn encode(input: Vec<u8>) -> Result<Self::Data, Self::Error> {
String::from_utf8(input).change_context(KmsError::Utf8DecodingFailed)
}
fn decode(input: Self::Data) -> Result<Vec<u8>, Self::Error> {
Ok(input.into_bytes())
}
}

pub struct Base64Encoded;

impl Decoder for Base64Encoded {
type Data = String;
type Error = error_stack::Report<KmsError>;

fn encode(input: Vec<u8>) -> Result<Self::Data, Self::Error> {
Ok(consts::BASE64_ENGINE.encode(input))
}
fn decode(input: Self::Data) -> Result<Vec<u8>, Self::Error> {
consts::BASE64_ENGINE
.decode(input)
.change_context(KmsError::Base64DecodingFailed)
}
}

pub struct HexEncoded;

impl Decoder for HexEncoded {
type Data = String;
type Error = error_stack::Report<KmsError>;

fn encode(input: Vec<u8>) -> Result<Self::Data, Self::Error> {
Ok(hex::encode(input))
}
fn decode(input: Self::Data) -> Result<Vec<u8>, Self::Error> {
hex::decode(input).change_context(KmsError::HexDecodingFailed)
}
}

impl<U: Decoder<Error = error_stack::Report<KmsError>>>
Encryption<KmsData<U>, KmsData<Base64Encoded>> for KmsClient
{
type ReturnType<'b, T> = Pin<Box<dyn Future<Output = error_stack::Result<T, KmsError>> + 'b>>;

fn encrypt(&self, _input: KmsData<U>) -> Self::ReturnType<'_, KmsData<Base64Encoded>> {
todo!()
}

fn decrypt<'a>(
&'a self,
input: KmsData<Base64Encoded>,
) -> Pin<Box<dyn Future<Output = error_stack::Result<KmsData<U>, KmsError>> + 'a>> {
Box::pin(async move {
let data = input.into_decoded()?;
let ciphertext_blob = Blob::new(data.clone());

let decrypt_output = self
.inner_client
.decrypt()
.key_id(&self.key_id)
.ciphertext_blob(ciphertext_blob)
.send()
.await
.map_err(|error| {
println!("Failed to KMS decrypt data: {error:?}");
error
})
.change_context(KmsError::DecryptionFailed)?;

let output = decrypt_output
.plaintext
.ok_or(report!(KmsError::MissingPlaintextDecryptionOutput))
.and_then(|blob| {
String::from_utf8(blob.into_inner())
.change_context(KmsError::Utf8DecodingFailed)
})?;
let decoded_output = consts::BASE64_ENGINE
.decode(output)
.change_context(KmsError::Base64DecodingFailed)?;

KmsData::encode(decoded_output)
})
}
}

pub struct Raw;

impl Decoder for Raw {
type Data = Vec<u8>;

type Error = error_stack::Report<KmsError>;

fn encode(input: Vec<u8>) -> Result<Self::Data, Self::Error> {
Ok(input)
}

fn decode(input: Self::Data) -> Result<Vec<u8>, Self::Error> {
Ok(input)
}
}

0 comments on commit ead558d

Please sign in to comment.