Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
/src/proto/
/Cargo.lock
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ prost = "0.11.6"
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] }
tokio = { version = "1", default-features = false, features = ["time"] }
rand = "0.8.5"
async-trait = "0.1.77"

[target.'cfg(genproto)'.build-dependencies]
prost-build = { version = "0.11.3" }
Expand Down
26 changes: 22 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use prost::Message;
use reqwest;
use reqwest::header::CONTENT_TYPE;
use reqwest::Client;
use std::collections::HashMap;
use std::default::Default;
use std::sync::Arc;

use crate::error::VssError;
use crate::headers::{get_headermap, FixedHeaders, VssHeaderProvider};
use crate::types::{
DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest,
ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
Expand All @@ -23,18 +25,27 @@ where
base_url: String,
client: Client,
retry_policy: R,
header_provider: Arc<dyn VssHeaderProvider>,
}

impl<R: RetryPolicy<E = VssError>> VssClient<R> {
/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
pub fn new(base_url: &str, retry_policy: R) -> Self {
pub fn new(base_url: String, retry_policy: R) -> Self {
let client = Client::new();
Self::from_client(base_url, client, retry_policy)
}

/// Constructs a [`VssClient`] from a given [`reqwest::Client`], using `base_url` as the VSS server endpoint.
pub fn from_client(base_url: &str, client: Client, retry_policy: R) -> Self {
Self { base_url: String::from(base_url), client, retry_policy }
pub fn from_client(base_url: String, client: Client, retry_policy: R) -> Self {
Self { base_url, client, retry_policy, header_provider: Arc::new(FixedHeaders::new(HashMap::new())) }
}

/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
///
/// HTTP headers will be provided by the given `header_provider`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Add newline after initial paragraph to improve doc rendering.

Suggested change
/// HTTP headers will be provided by the given `header_provider`.
///
/// HTTP headers will be provided by the given `header_provider`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

pub fn new_with_headers(base_url: String, retry_policy: R, header_provider: Arc<dyn VssHeaderProvider>) -> Self {
let client = Client::new();
Self { base_url, client, retry_policy, header_provider }
}

/// Returns the underlying base URL.
Expand Down Expand Up @@ -111,10 +122,17 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {

async fn post_request<Rq: Message, Rs: Message + Default>(&self, request: &Rq, url: &str) -> Result<Rs, VssError> {
let request_body = request.encode_to_vec();
let headermap = self
.header_provider
.get_headers(&request_body)
.await
.and_then(get_headermap)
.map_err(|e| VssError::AuthError(e.to_string()))?;
let response_raw = self
.client
.post(url)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.headers(headermap)
.body(request_body)
.send()
.await?;
Expand Down
72 changes: 72 additions & 0 deletions src/headers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Display;
use std::fmt::Formatter;
use std::str::FromStr;

/// Defines a trait around how headers are provided for each VSS request.
#[async_trait]
pub trait VssHeaderProvider {
/// Returns the HTTP headers to be used for a VSS request.
/// This method is called on each request, and should likely perform some form of caching.
///
/// A reference to the serialized request body is given as `request`.
/// It can be used to perform operations such as request signing.
async fn get_headers(&self, request: &[u8]) -> Result<HashMap<String, String>, VssHeaderProviderError>;
}

/// Errors around providing headers for each VSS request.
#[derive(Debug)]
pub enum VssHeaderProviderError {
/// Invalid data was encountered.
InvalidData {
/// The error message.
error: String,
},
}

impl Display for VssHeaderProviderError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidData { error } => {
write!(f, "invalid data: {}", error)
}
}
}
}

impl Error for VssHeaderProviderError {}

/// A header provider returning an given, fixed set of headers.
pub struct FixedHeaders {
headers: HashMap<String, String>,
}

impl FixedHeaders {
/// Creates a new header provider returning the given, fixed set of headers.
pub fn new(headers: HashMap<String, String>) -> FixedHeaders {
FixedHeaders { headers }
}
}

#[async_trait]
impl VssHeaderProvider for FixedHeaders {
async fn get_headers(&self, _request: &[u8]) -> Result<HashMap<String, String>, VssHeaderProviderError> {
Ok(self.headers.clone())
}
}

pub(crate) fn get_headermap(headers: HashMap<String, String>) -> Result<HeaderMap, VssHeaderProviderError> {
let mut headermap = HeaderMap::new();
for (name, value) in headers {
headermap.insert(
reqwest::header::HeaderName::from_str(&name)
.map_err(|e| VssHeaderProviderError::InvalidData { error: e.to_string() })?,
reqwest::header::HeaderValue::from_str(&value)
.map_err(|e| VssHeaderProviderError::InvalidData { error: e.to_string() })?,
);
}
Ok(headermap)
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ pub mod util;

// Encryption-Decryption related crate-only helpers.
pub(crate) mod crypto;

/// A collection of header providers.
pub mod headers;
82 changes: 71 additions & 11 deletions tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#[cfg(test)]
mod tests {
use async_trait::async_trait;
use mockito::{self, Matcher};
use prost::Message;
use reqwest::header::CONTENT_TYPE;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use vss_client::client::VssClient;
use vss_client::error::VssError;
use vss_client::headers::FixedHeaders;
use vss_client::headers::VssHeaderProvider;
use vss_client::headers::VssHeaderProviderError;

use vss_client::types::{
DeleteObjectRequest, DeleteObjectResponse, ErrorCode, ErrorResponse, GetObjectRequest, GetObjectResponse,
Expand Down Expand Up @@ -41,7 +47,42 @@ mod tests {
.create();

// Create a new VssClient with the mock server URL.
let client = VssClient::new(&base_url, retry_policy());
let client = VssClient::new(base_url, retry_policy());

let actual_result = client.get_object(&get_request).await.unwrap();

let expected_result = &mock_response;
assert_eq!(actual_result, *expected_result);

// Verify server endpoint was called exactly once.
mock_server.expect(1).assert();
}

#[tokio::test]
async fn test_get_with_headers() {
// Spin-up mock server with mock response for given request.
let base_url = mockito::server_url().to_string();

// Set up the mock request/response.
let get_request = GetObjectRequest { store_id: "store".to_string(), key: "k1".to_string() };
let mock_response = GetObjectResponse {
value: Some(KeyValue { key: "k1".to_string(), version: 2, value: b"k1v2".to_vec() }),
..Default::default()
};

// Register the mock endpoint with the mockito server and provide expected headers.
let mock_server = mockito::mock("POST", GET_OBJECT_ENDPOINT)
.match_header(CONTENT_TYPE.as_str(), APPLICATION_OCTET_STREAM)
.match_header("headerkey", "headervalue")
.match_body(get_request.encode_to_vec())
.with_status(200)
.with_body(mock_response.encode_to_vec())
.create();

// Create a new VssClient with the mock server URL and fixed headers.
let header_provider =
Arc::new(FixedHeaders::new(HashMap::from([("headerkey".to_string(), "headervalue".to_string())])));
let client = VssClient::new_with_headers(base_url, retry_policy(), header_provider);

let actual_result = client.get_object(&get_request).await.unwrap();

Expand Down Expand Up @@ -75,7 +116,7 @@ mod tests {
.create();

// Create a new VssClient with the mock server URL.
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());
let actual_result = vss_client.put_object(&request).await.unwrap();

let expected_result = &mock_response;
Expand Down Expand Up @@ -106,7 +147,7 @@ mod tests {
.create();

// Create a new VssClient with the mock server URL.
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());
let actual_result = vss_client.delete_object(&request).await.unwrap();

let expected_result = &mock_response;
Expand Down Expand Up @@ -147,7 +188,7 @@ mod tests {
.create();

// Create a new VssClient with the mock server URL.
let client = VssClient::new(&base_url, retry_policy());
let client = VssClient::new(base_url, retry_policy());

let actual_result = client.list_key_versions(&request).await.unwrap();

Expand All @@ -161,7 +202,7 @@ mod tests {
#[tokio::test]
async fn test_no_such_key_err_handling() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

// NoSuchKeyError
let error_response = ErrorResponse {
Expand All @@ -185,7 +226,7 @@ mod tests {
#[tokio::test]
async fn test_get_response_without_value() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

// GetObjectResponse with None value
let mock_response = GetObjectResponse { value: None, ..Default::default() };
Expand All @@ -206,7 +247,7 @@ mod tests {
#[tokio::test]
async fn test_invalid_request_err_handling() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

// Invalid Request Error
let error_response = ErrorResponse {
Expand Down Expand Up @@ -258,7 +299,7 @@ mod tests {
#[tokio::test]
async fn test_auth_err_handling() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

// Invalid Request Error
let error_response =
Expand Down Expand Up @@ -305,10 +346,29 @@ mod tests {
mock_server.expect(4).assert();
}

struct FailingHeaderProvider {}

#[async_trait]
impl VssHeaderProvider for FailingHeaderProvider {
async fn get_headers(&self, _request: &[u8]) -> Result<HashMap<String, String>, VssHeaderProviderError> {
Err(VssHeaderProviderError::InvalidData { error: "test".to_string() })
}
}

#[tokio::test]
async fn test_header_provider_error() {
let get_request = GetObjectRequest { store_id: "store".to_string(), key: "k1".to_string() };
let header_provider = Arc::new(FailingHeaderProvider {});
let client = VssClient::new_with_headers("notused".to_string(), retry_policy(), header_provider);
let result = client.get_object(&get_request).await;

assert!(matches!(result, Err(VssError::AuthError { .. })));
}

#[tokio::test]
async fn test_conflict_err_handling() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

// Conflict Error
let error_response =
Expand All @@ -335,7 +395,7 @@ mod tests {
#[tokio::test]
async fn test_internal_server_err_handling() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

// Internal Server Error
let error_response = ErrorResponse {
Expand Down Expand Up @@ -387,7 +447,7 @@ mod tests {
#[tokio::test]
async fn test_internal_err_handling() {
let base_url = mockito::server_url();
let vss_client = VssClient::new(&base_url, retry_policy());
let vss_client = VssClient::new(base_url, retry_policy());

let error_response = ErrorResponse { error_code: 999, message: "UnknownException".to_string() };
let mut _mock_server = mockito::mock("POST", Matcher::Any)
Expand Down