From b06b37cb9c01499ae5b75c06d8524f46f03be236 Mon Sep 17 00:00:00 2001 From: Willem Van Lint Date: Wed, 15 May 2024 14:39:00 -0700 Subject: [PATCH] Introduce header provider trait --- .gitignore | 1 + Cargo.toml | 1 + src/client.rs | 26 ++++++++++++--- src/headers/mod.rs | 72 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 ++ tests/tests.rs | 82 +++++++++++++++++++++++++++++++++++++++------- 6 files changed, 170 insertions(+), 15 deletions(-) create mode 100644 src/headers/mod.rs diff --git a/.gitignore b/.gitignore index 1b70d9a..ed3ca36 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target /src/proto/ +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index e8f8d38..0d16e40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ prost = "0.11.6" reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] } tokio = { version = "1", default-features = false, features = ["time"] } rand = "0.8.5" +async-trait = "0.1.77" [target.'cfg(genproto)'.build-dependencies] prost-build = { version = "0.11.3" } diff --git a/src/client.rs b/src/client.rs index 6fcc5d2..ffe8b63 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,12 @@ use prost::Message; -use reqwest; use reqwest::header::CONTENT_TYPE; use reqwest::Client; +use std::collections::HashMap; use std::default::Default; +use std::sync::Arc; use crate::error::VssError; +use crate::headers::{get_headermap, FixedHeaders, VssHeaderProvider}; use crate::types::{ DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest, ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse, @@ -23,18 +25,27 @@ where base_url: String, client: Client, retry_policy: R, + header_provider: Arc, } impl> VssClient { /// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint. - pub fn new(base_url: &str, retry_policy: R) -> Self { + pub fn new(base_url: String, retry_policy: R) -> Self { let client = Client::new(); Self::from_client(base_url, client, retry_policy) } /// Constructs a [`VssClient`] from a given [`reqwest::Client`], using `base_url` as the VSS server endpoint. - pub fn from_client(base_url: &str, client: Client, retry_policy: R) -> Self { - Self { base_url: String::from(base_url), client, retry_policy } + pub fn from_client(base_url: String, client: Client, retry_policy: R) -> Self { + Self { base_url, client, retry_policy, header_provider: Arc::new(FixedHeaders::new(HashMap::new())) } + } + + /// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint. + /// + /// HTTP headers will be provided by the given `header_provider`. + pub fn new_with_headers(base_url: String, retry_policy: R, header_provider: Arc) -> Self { + let client = Client::new(); + Self { base_url, client, retry_policy, header_provider } } /// Returns the underlying base URL. @@ -111,10 +122,17 @@ impl> VssClient { async fn post_request(&self, request: &Rq, url: &str) -> Result { let request_body = request.encode_to_vec(); + let headermap = self + .header_provider + .get_headers(&request_body) + .await + .and_then(get_headermap) + .map_err(|e| VssError::AuthError(e.to_string()))?; let response_raw = self .client .post(url) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) + .headers(headermap) .body(request_body) .send() .await?; diff --git a/src/headers/mod.rs b/src/headers/mod.rs new file mode 100644 index 0000000..901480f --- /dev/null +++ b/src/headers/mod.rs @@ -0,0 +1,72 @@ +use async_trait::async_trait; +use reqwest::header::HeaderMap; +use std::collections::HashMap; +use std::error::Error; +use std::fmt::Display; +use std::fmt::Formatter; +use std::str::FromStr; + +/// Defines a trait around how headers are provided for each VSS request. +#[async_trait] +pub trait VssHeaderProvider { + /// Returns the HTTP headers to be used for a VSS request. + /// This method is called on each request, and should likely perform some form of caching. + /// + /// A reference to the serialized request body is given as `request`. + /// It can be used to perform operations such as request signing. + async fn get_headers(&self, request: &[u8]) -> Result, VssHeaderProviderError>; +} + +/// Errors around providing headers for each VSS request. +#[derive(Debug)] +pub enum VssHeaderProviderError { + /// Invalid data was encountered. + InvalidData { + /// The error message. + error: String, + }, +} + +impl Display for VssHeaderProviderError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidData { error } => { + write!(f, "invalid data: {}", error) + } + } + } +} + +impl Error for VssHeaderProviderError {} + +/// A header provider returning an given, fixed set of headers. +pub struct FixedHeaders { + headers: HashMap, +} + +impl FixedHeaders { + /// Creates a new header provider returning the given, fixed set of headers. + pub fn new(headers: HashMap) -> FixedHeaders { + FixedHeaders { headers } + } +} + +#[async_trait] +impl VssHeaderProvider for FixedHeaders { + async fn get_headers(&self, _request: &[u8]) -> Result, VssHeaderProviderError> { + Ok(self.headers.clone()) + } +} + +pub(crate) fn get_headermap(headers: HashMap) -> Result { + let mut headermap = HeaderMap::new(); + for (name, value) in headers { + headermap.insert( + reqwest::header::HeaderName::from_str(&name) + .map_err(|e| VssHeaderProviderError::InvalidData { error: e.to_string() })?, + reqwest::header::HeaderValue::from_str(&value) + .map_err(|e| VssHeaderProviderError::InvalidData { error: e.to_string() })?, + ); + } + Ok(headermap) +} diff --git a/src/lib.rs b/src/lib.rs index 288dca5..d418d62 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,3 +25,6 @@ pub mod util; // Encryption-Decryption related crate-only helpers. pub(crate) mod crypto; + +/// A collection of header providers. +pub mod headers; diff --git a/tests/tests.rs b/tests/tests.rs index f86dea0..eb9a5cf 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,11 +1,17 @@ #[cfg(test)] mod tests { + use async_trait::async_trait; use mockito::{self, Matcher}; use prost::Message; use reqwest::header::CONTENT_TYPE; + use std::collections::HashMap; + use std::sync::Arc; use std::time::Duration; use vss_client::client::VssClient; use vss_client::error::VssError; + use vss_client::headers::FixedHeaders; + use vss_client::headers::VssHeaderProvider; + use vss_client::headers::VssHeaderProviderError; use vss_client::types::{ DeleteObjectRequest, DeleteObjectResponse, ErrorCode, ErrorResponse, GetObjectRequest, GetObjectResponse, @@ -41,7 +47,42 @@ mod tests { .create(); // Create a new VssClient with the mock server URL. - let client = VssClient::new(&base_url, retry_policy()); + let client = VssClient::new(base_url, retry_policy()); + + let actual_result = client.get_object(&get_request).await.unwrap(); + + let expected_result = &mock_response; + assert_eq!(actual_result, *expected_result); + + // Verify server endpoint was called exactly once. + mock_server.expect(1).assert(); + } + + #[tokio::test] + async fn test_get_with_headers() { + // Spin-up mock server with mock response for given request. + let base_url = mockito::server_url().to_string(); + + // Set up the mock request/response. + let get_request = GetObjectRequest { store_id: "store".to_string(), key: "k1".to_string() }; + let mock_response = GetObjectResponse { + value: Some(KeyValue { key: "k1".to_string(), version: 2, value: b"k1v2".to_vec() }), + ..Default::default() + }; + + // Register the mock endpoint with the mockito server and provide expected headers. + let mock_server = mockito::mock("POST", GET_OBJECT_ENDPOINT) + .match_header(CONTENT_TYPE.as_str(), APPLICATION_OCTET_STREAM) + .match_header("headerkey", "headervalue") + .match_body(get_request.encode_to_vec()) + .with_status(200) + .with_body(mock_response.encode_to_vec()) + .create(); + + // Create a new VssClient with the mock server URL and fixed headers. + let header_provider = + Arc::new(FixedHeaders::new(HashMap::from([("headerkey".to_string(), "headervalue".to_string())]))); + let client = VssClient::new_with_headers(base_url, retry_policy(), header_provider); let actual_result = client.get_object(&get_request).await.unwrap(); @@ -75,7 +116,7 @@ mod tests { .create(); // Create a new VssClient with the mock server URL. - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); let actual_result = vss_client.put_object(&request).await.unwrap(); let expected_result = &mock_response; @@ -106,7 +147,7 @@ mod tests { .create(); // Create a new VssClient with the mock server URL. - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); let actual_result = vss_client.delete_object(&request).await.unwrap(); let expected_result = &mock_response; @@ -147,7 +188,7 @@ mod tests { .create(); // Create a new VssClient with the mock server URL. - let client = VssClient::new(&base_url, retry_policy()); + let client = VssClient::new(base_url, retry_policy()); let actual_result = client.list_key_versions(&request).await.unwrap(); @@ -161,7 +202,7 @@ mod tests { #[tokio::test] async fn test_no_such_key_err_handling() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); // NoSuchKeyError let error_response = ErrorResponse { @@ -185,7 +226,7 @@ mod tests { #[tokio::test] async fn test_get_response_without_value() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); // GetObjectResponse with None value let mock_response = GetObjectResponse { value: None, ..Default::default() }; @@ -206,7 +247,7 @@ mod tests { #[tokio::test] async fn test_invalid_request_err_handling() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); // Invalid Request Error let error_response = ErrorResponse { @@ -258,7 +299,7 @@ mod tests { #[tokio::test] async fn test_auth_err_handling() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); // Invalid Request Error let error_response = @@ -305,10 +346,29 @@ mod tests { mock_server.expect(4).assert(); } + struct FailingHeaderProvider {} + + #[async_trait] + impl VssHeaderProvider for FailingHeaderProvider { + async fn get_headers(&self, _request: &[u8]) -> Result, VssHeaderProviderError> { + Err(VssHeaderProviderError::InvalidData { error: "test".to_string() }) + } + } + + #[tokio::test] + async fn test_header_provider_error() { + let get_request = GetObjectRequest { store_id: "store".to_string(), key: "k1".to_string() }; + let header_provider = Arc::new(FailingHeaderProvider {}); + let client = VssClient::new_with_headers("notused".to_string(), retry_policy(), header_provider); + let result = client.get_object(&get_request).await; + + assert!(matches!(result, Err(VssError::AuthError { .. }))); + } + #[tokio::test] async fn test_conflict_err_handling() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); // Conflict Error let error_response = @@ -335,7 +395,7 @@ mod tests { #[tokio::test] async fn test_internal_server_err_handling() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); // Internal Server Error let error_response = ErrorResponse { @@ -387,7 +447,7 @@ mod tests { #[tokio::test] async fn test_internal_err_handling() { let base_url = mockito::server_url(); - let vss_client = VssClient::new(&base_url, retry_policy()); + let vss_client = VssClient::new(base_url, retry_policy()); let error_response = ErrorResponse { error_code: 999, message: "UnknownException".to_string() }; let mut _mock_server = mockito::mock("POST", Matcher::Any)