Skip to content

Commit

Permalink
[wip] refacto client
Browse files Browse the repository at this point in the history
  • Loading branch information
ghubertpalo committed May 31, 2023
1 parent 0125365 commit 7840930
Show file tree
Hide file tree
Showing 10 changed files with 570 additions and 77 deletions.
4 changes: 3 additions & 1 deletion mithril-client/src/aggregator.rs
Expand Up @@ -284,7 +284,7 @@ impl AggregatorHandler for AggregatorHTTPClient {
}
}

/// Download Snapshot
/// Download Snapshot. Returns the path of the downloaded file as String.
async fn download_snapshot(
&self,
digest: &str,
Expand All @@ -310,6 +310,7 @@ impl AggregatorHandler for AggregatorHTTPClient {
})?;
let mut bytes_downloaded = 0;
let mut remote_stream = response.bytes_stream();

while let Some(item) = remote_stream.next().await {
let chunk = item.map_err(|e| {
AggregatorHandlerError::RemoteServerTechnical(e.to_string())
Expand Down Expand Up @@ -362,6 +363,7 @@ impl AggregatorHandler for AggregatorHTTPClient {
let unpack_dir_path = local_path.parent().unwrap().join(path::Path::new("db"));
let mut snapshot_archive = Archive::new(snapshot_file_tar);
snapshot_archive.unpack(&unpack_dir_path)?;

Ok(unpack_dir_path.into_os_string().into_string().unwrap())
}

Expand Down
193 changes: 193 additions & 0 deletions mithril-client/src/aggregator_client/http_client.rs
@@ -0,0 +1,193 @@
use std::{path::Path, sync::Arc};

use async_recursion::async_recursion;
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::{Client, Response, StatusCode};
use semver::Version;
use slog_scope::debug;
use thiserror::Error;
use tokio::{fs, io::AsyncWriteExt, sync::RwLock};

#[cfg(test)]
use mockall::automock;

use mithril_common::{StdError, MITHRIL_API_VERSION_HEADER};

#[derive(Error, Debug)]
pub enum AggregatorHTTPClientError {
/// Error raised when querying the aggregator returned a 5XX error.
#[error("remote server technical error: '{0}'")]
RemoteServerTechnical(String),

/// Error raised when querying the aggregator returned a 4XX error.
#[error("remote server logical error: '{0}'")]
RemoteServerLogical(String),

/// Error raised when the aggregator can't be reached.
#[error("remote server unreachable: '{0}'")]
RemoteServerUnreachable(String),

/// Error raised when the server API version mismatch the client API version.
#[error("API version mismatch: {0}")]
ApiVersionMismatch(String),

/// HTTP subsystem error
#[error("HTTP subsystem error: {message} ({error}).")]
SubsystemError { message: String, error: StdError },
}

#[async_trait]
pub trait AggregatorClient {
async fn get_json(&self, url: &str) -> Result<String, AggregatorHTTPClientError>;

async fn download(&self, url: &str, filepath: &Path) -> Result<(), AggregatorHTTPClientError>;
}

/// Responsible of HTTP transport and API version check.
pub struct AggregatorHTTPClient {
network: String,
aggregator_endpoint: String,
api_versions: Arc<RwLock<Vec<Version>>>,
}

impl AggregatorHTTPClient {
/// AggregatorHTTPClient factory
pub fn new(network: String, aggregator_endpoint: String, api_versions: Vec<Version>) -> Self {
debug!("New AggregatorHTTPClient created");
Self {
network,
aggregator_endpoint,
api_versions: Arc::new(RwLock::new(api_versions)),
}
}

/// Computes the current api version
async fn compute_current_api_version(&self) -> Option<Version> {
self.api_versions.read().await.first().cloned()
}

/// Discards the current api version
/// It discards the current version if and only if there is at least 2 versions available
async fn discard_current_api_version(&self) -> Option<Version> {
if self.api_versions.read().await.len() < 2 {
return None;
}
if let Some(current_api_version) = self.compute_current_api_version().await {
let mut api_versions = self.api_versions.write().await;
if let Some(index) = api_versions
.iter()
.position(|value| *value == current_api_version)
{
api_versions.remove(index);
return Some(current_api_version);
}
}
None
}

/// Perform a HTTP GET request on the Aggregator and return the given JSON
#[async_recursion]
async fn get(&self, url: &str) -> Result<Response, AggregatorHTTPClientError> {
let request_builder = Client::new().get(url.to_owned());
let current_api_version = self
.compute_current_api_version()
.await
.unwrap()
.to_string();
debug!("Prepare request with version: {}", current_api_version);
let request_builder =
request_builder.header(MITHRIL_API_VERSION_HEADER, current_api_version);
let response = request_builder.send().await.map_err(|e| {
AggregatorHTTPClientError::SubsystemError {
message: format!(
"Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
),
error: e.into(),
}
})?;

match response.status() {
StatusCode::OK => Ok(response),
StatusCode::PRECONDITION_FAILED => {
if self.discard_current_api_version().await.is_some()
&& !self.api_versions.read().await.is_empty()
{
return self.get(url).await;
}

return Err(self.handle_api_error(&response).await);
}
StatusCode::NOT_FOUND => Err(AggregatorHTTPClientError::RemoteServerLogical(format!(
"Url='{url} not found"
))),
status_code => Err(AggregatorHTTPClientError::RemoteServerTechnical(format!(
"Unhandled error {status_code}"
))),
}
}

/// API version error handling
async fn handle_api_error(&self, response: &Response) -> AggregatorHTTPClientError {
if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
AggregatorHTTPClientError::ApiVersionMismatch(format!(
"server version: '{}', signer version: '{}'",
version.to_str().unwrap(),
self.compute_current_api_version().await.unwrap()
))
} else {
AggregatorHTTPClientError::ApiVersionMismatch(format!(
"version precondition failed, sent version '{}'.",
self.compute_current_api_version().await.unwrap()
))
}
}
}

#[cfg_attr(test, automock)]
#[async_trait]
impl AggregatorClient for AggregatorHTTPClient {
async fn get_json(&self, url: &str) -> Result<String, AggregatorHTTPClientError> {
let url = format!("{}/{}", self.aggregator_endpoint, url);
let response = self.get(&url).await?;

response
.json()
.await
.map_err(|e| AggregatorHTTPClientError::SubsystemError {
message: format!("Could not find a JSON body in the response {response:?}"),
error: e.into(),
})
}

async fn download(&self, url: &str, filepath: &Path) -> Result<(), AggregatorHTTPClientError> {
let url = format!("{}/{}", self.aggregator_endpoint, url);
let response = self.get(&url).await?;
let mut local_file = fs::File::create(filepath).await.map_err(|e| {
AggregatorHTTPClientError::SubsystemError {
message: format!(
"Could not create download archive '{}'.",
filepath.display()
),
error: e.into(),
}
})?;
let bytes_total = response.content_length().ok_or_else(|| {
AggregatorHTTPClientError::RemoteServerTechnical(
"cannot get response content length".to_string(),
)
})?;
let mut remote_stream = response.bytes_stream();

while let Some(item) = remote_stream.next().await {
let chunk = item.map_err(|e| {
AggregatorHTTPClientError::RemoteServerTechnical(format!(
"Download: Could not read from byte stream: {e}"
))
})?;
local_file.write_all(&chunk);
}

Ok(())
}
}
5 changes: 5 additions & 0 deletions mithril-client/src/aggregator_client/mod.rs
@@ -0,0 +1,5 @@
mod http_client;
mod snapshot_client;

pub use http_client::*;
pub use snapshot_client::*;
82 changes: 82 additions & 0 deletions mithril-client/src/aggregator_client/snapshot_client.rs
@@ -0,0 +1,82 @@
//! This module contains a struct to exchange snapshot information with the Aggregator

use std::{
path::{Path, PathBuf},
sync::Arc,
};

use mithril_common::{
entities::Snapshot,
messages::{SnapshotListMessage, SnapshotMessage},
StdError, StdResult,
};
use slog_scope::warn;
use thiserror::Error;

use crate::{FromSnapshotListMessageAdapter, FromSnapshotMessageAdapter};

use super::AggregatorClient;

#[derive(Error, Debug)]
pub enum SnapshotClientError {
#[error("Could not find a working download location for the snapshot digest '{digest}', tried location: {{'{locations}'}}.")]
NoWorkingLocation { digest: String, locations: String },

#[error("subsystem error: '{message}'; nested error: {error}")]
SubsystemError { message: String, error: StdError },
}

pub struct SnapshotClient {
http_client: Arc<dyn AggregatorClient>,
download_dir: PathBuf,
}

impl SnapshotClient {
pub fn new(http_client: Arc<dyn AggregatorClient>, download_dir: &Path) -> Self {
Self {
http_client,
download_dir: download_dir.to_owned(),
}
}

pub async fn list(&self) -> StdResult<Vec<Snapshot>> {
let url = "/artifact/snapshots";
let response = self.http_client.get_json(url).await?;
let message = serde_json::from_str::<SnapshotListMessage>(&response)?;
let snapshots = FromSnapshotListMessageAdapter::adapt(message);

Ok(snapshots)
}

pub async fn show(&self, digest: &str) -> StdResult<Snapshot> {
let url = format!("/artifact/snapshot/{}", digest);
let response = self.http_client.get_json(&url).await?;
let message = serde_json::from_str::<SnapshotMessage>(&response)?;
let snapshot = FromSnapshotMessageAdapter::adapt(message);

Ok(snapshot)
}

pub async fn download(&self, snapshot: &Snapshot) -> StdResult<String> {
let filepath = PathBuf::new()
.join(&self.download_dir)
.join(format!("snapshot-{}", snapshot.digest));

while let Some(url) = snapshot.locations.iter().next() {
match self.http_client.download(url, &filepath).await {
Ok(()) => return Ok(filepath.display().to_string()),
Err(e) => {
warn!("Failed downloading snapshot from '{url}'.");
}
};
}

let locations = snapshot.locations.join(", ");

Err(SnapshotClientError::NoWorkingLocation {
digest: snapshot.digest.clone(),
locations,
}
.into())
}
}
2 changes: 0 additions & 2 deletions mithril-client/src/commands/download.rs
Expand Up @@ -6,8 +6,6 @@ use mithril_common::{api_version::APIVersionProvider, StdError};
use serde::Serialize;
use slog_scope::debug;

use crate::{AggregatorHTTPClient, Config, Runtime};

/// Download a snapshot.
#[derive(Parser, Debug, Clone)]
pub struct DownloadCommand {
Expand Down
7 changes: 1 addition & 6 deletions mithril-client/src/lib.rs
Expand Up @@ -11,18 +11,13 @@
//! [Digester](mithril_common::digesters::ImmutableDigester)
//! implementations using the `with_xxx` methods.

mod aggregator;
pub mod aggregator_client;
pub mod commands;
mod entities;
mod message_adapters;
mod runtime;
pub mod services;

pub use aggregator::{AggregatorHTTPClient, AggregatorHandler, AggregatorHandlerError};
pub use entities::Config;
pub use message_adapters::{
FromCertificateMessageAdapter, FromSnapshotListMessageAdapter, FromSnapshotMessageAdapter,
};
pub use runtime::{Runtime, RuntimeError};

pub use runtime::convert_to_field_items;
7 changes: 0 additions & 7 deletions mithril-client/src/runtime.rs
Expand Up @@ -3,7 +3,6 @@ use std::str;
use std::sync::Arc;
use thiserror::Error;

use crate::aggregator::{AggregatorHandler, AggregatorHandlerError};
use crate::entities::*;

use mithril_common::certificate_chain::{
Expand All @@ -24,11 +23,6 @@ pub enum RuntimeError {
#[error("an input is invalid: '{0}'")]
InvalidInput(String),

/// Error raised when an [AggregatorHandlerError] is caught when querying the aggregator using
/// a [AggregatorHandler].
#[error("aggregator handler error: '{0}'")]
AggregatorHandler(#[from] AggregatorHandlerError),

/// Error raised when a CertificateRetrieverError tries to retrieve a
/// [certificate](mithril_common::entities::Certificate)
#[error("certificate retriever error: '{0}'")]
Expand Down Expand Up @@ -192,7 +186,6 @@ mod tests {
use mithril_common::crypto_helper::{ProtocolGenesisSigner, ProtocolGenesisVerifier};
use mockall::mock;

use crate::aggregator::AggregatorHandlerError;
use mithril_common::certificate_chain::{
CertificateRetriever, CertificateRetrieverError, CertificateVerifierError,
};
Expand Down

0 comments on commit 7840930

Please sign in to comment.