Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keylime-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ async fn main() -> Result<()> {
attest,
signature,
ak_handle,
retry_config: None,
};
match keylime::agent_registration::register_agent(aa, &mut ctx).await {
Ok(()) => (),
Expand Down
5 changes: 2 additions & 3 deletions keylime-push-model-agent/src/attestation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ impl AttestationClient {
let req = filler.get_attestation_request();
debug!("Request body: {:?}", serde_json::to_string(&req));

// --- Now using send_json, which has the retry logic ---
let response = self
.client
.send_json(reqwest::Method::POST, config.url, &req)?
.get_json_request(reqwest::Method::POST, config.url, &req)?
.send()
.await?;

Expand Down Expand Up @@ -114,7 +113,7 @@ impl AttestationClient {

let response = self
.client
.send_json(reqwest::Method::PATCH, config.url, &json_body)?
.get_json_request(reqwest::Method::PATCH, config.url, &json_body)?
.send()
.await?;

Expand Down
30 changes: 28 additions & 2 deletions keylime-push-model-agent/src/registration.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use keylime::{
agent_registration::{AgentRegistration, AgentRegistrationConfig},
agent_registration::{
AgentRegistration, AgentRegistrationConfig, RetryConfig,
},
cert,
config::PushModelConfigTrait,
context_info,
Expand All @@ -20,6 +22,23 @@ pub async fn check_registration<T: PushModelConfigTrait>(
Ok(())
}

fn get_retry_config<T: PushModelConfigTrait>(
config: &T,
) -> Option<RetryConfig> {
if config.expbackoff_max_retries().is_none()
&& config.expbackoff_initial_delay().is_none()
&& config.expbackoff_max_delay().is_none()
{
None
} else {
Some(RetryConfig {
max_retries: config.expbackoff_max_retries().unwrap_or(0),
initial_delay_ms: config.expbackoff_initial_delay().unwrap_or(0),
max_delay_ms: *config.expbackoff_max_delay(),
})
}
}

pub async fn register_agent<T: PushModelConfigTrait>(
config: &T,
context_info: &mut context_info::ContextInfo,
Expand Down Expand Up @@ -48,6 +67,8 @@ pub async fn register_agent<T: PushModelConfigTrait>(

let server_cert_key = cert::cert_from_server_key(&cert_config)?;

let retry_config = get_retry_config(config);

let aa = AgentRegistration {
ak: context_info.ak.clone(),
ek_result: context_info.ek_result.clone(),
Expand All @@ -63,6 +84,7 @@ pub async fn register_agent<T: PushModelConfigTrait>(
attest: None, // TODO: Check how to proceed with attestation, normally, no device ID means no attest
signature: None, // TODO: Normally, no device ID means no signature
ak_handle: context_info.ak_handle,
retry_config,
};
let ctx = context_info.get_mutable_tpm_context();
match keylime::agent_registration::register_agent(aa, ctx).await {
Expand Down Expand Up @@ -95,14 +117,18 @@ mod tests {
async fn test_register_agent() {
let _mutex = testing::lock_tests().await;
let tmpdir = tempfile::tempdir().expect("failed to create tmpdir");
let config = get_testing_config(tmpdir.path());
let mut config = get_testing_config(tmpdir.path());
let alg_config = AlgorithmConfigurationString {
tpm_encryption_alg: "rsa".to_string(),
tpm_hash_alg: "sha256".to_string(),
tpm_signing_alg: "rsassa".to_string(),
agent_data_path: "".to_string(),
disabled_signing_algorithms: vec![],
};
config.expbackoff_initial_delay = None;
config.expbackoff_max_retries = None;
config.expbackoff_max_delay = None;

let mut context_info = ContextInfo::new_from_str(alg_config)
.expect("Failed to create context info from string");
let result = register_agent(&config, &mut context_info).await;
Expand Down
10 changes: 10 additions & 0 deletions keylime/src/agent_registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ pub struct AgentRegistrationConfig {
pub registrar_port: u32,
}

#[derive(Debug, Default, Clone)]
pub struct RetryConfig {
pub max_retries: u32,
pub initial_delay_ms: u64,
pub max_delay_ms: Option<u64>,
}

#[derive(Debug)]
pub struct AgentRegistration {
pub ak: tpm::AKResult,
Expand All @@ -37,6 +44,7 @@ pub struct AgentRegistration {
pub attest: Option<tss_esapi::structures::Attest>,
pub signature: Option<tss_esapi::structures::Signature>,
pub ak_handle: KeyHandle,
pub retry_config: Option<RetryConfig>,
}

pub async fn register_agent(
Expand Down Expand Up @@ -105,11 +113,13 @@ pub async fn register_agent(
let ai = ai_builder.build().await?;

let ac = &aa.agent_registration_config;

// Build the registrar client
// Create a RegistrarClientBuilder and set the parameters
let mut registrar_client = RegistrarClientBuilder::new()
.registrar_address(ac.registrar_ip.clone())
.registrar_port(ac.registrar_port)
.retry_config(aa.retry_config.clone())
.build()
.await?;

Expand Down
95 changes: 83 additions & 12 deletions keylime/src/registrar_client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
use crate::{agent_identity::AgentIdentity, serialization::*};
use crate::resilient_client::ResilientClient;
use crate::{
agent_identity::AgentIdentity, agent_registration::RetryConfig,
serialization::*,
};
use log::*;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Number;
use std::net::IpAddr;
use std::time::Duration;
use thiserror::Error;

use crate::version::KeylimeRegistrarVersion;
Expand Down Expand Up @@ -30,6 +36,10 @@ pub enum RegistrarClientBuilderError {
/// Reqwest error
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),

/// Middleware error
#[error("Middleware error: {0}")]
Middleware(#[from] reqwest_middleware::Error),
}

#[derive(Debug, Default)]
Expand All @@ -38,6 +48,7 @@ pub struct RegistrarClientBuilder {
registrar_supported_api_versions: Option<Vec<String>>,
registrar_address: Option<String>,
registrar_port: Option<u32>,
retry_config: Option<RetryConfig>,
}

impl RegistrarClientBuilder {
Expand Down Expand Up @@ -67,6 +78,16 @@ impl RegistrarClientBuilder {
self
}

/// Set the RetryConfig for the registrar client
///
/// # Arguments:
///
/// * rt: RetryConfig: The retry configuration to use for the registrar client
pub fn retry_config(mut self, rt: Option<RetryConfig>) -> Self {
self.retry_config = rt;
self
}

/// Parse the received address
fn parse_registrar_address(address: String) -> String {
// Parse the registrar IP or hostname
Expand Down Expand Up @@ -105,11 +126,30 @@ impl RegistrarClientBuilder {

info!("Requesting registrar API version to {addr}");

let resp = reqwest::Client::new()
.get(&addr)
.send()
.await
.map_err(RegistrarClientBuilderError::Reqwest)?;
let resp = if let Some(retry_config) = &self.retry_config {
debug!(
"Using ResilientClient for version check with {} retries.",
retry_config.max_retries
);
let client = ResilientClient::new(
None,
Duration::from_millis(retry_config.initial_delay_ms),
retry_config.max_retries,
&[StatusCode::OK],
retry_config.max_delay_ms.map(Duration::from_millis),
);

client
.get_request(reqwest::Method::GET, &addr)
.send()
.await?
} else {
reqwest::Client::new()
.get(&addr)
.send()
.await
.map_err(RegistrarClientBuilderError::Reqwest)?
};

if !resp.status().is_success() {
info!("Registrar at '{addr}' does not support the '/version' endpoint");
Expand Down Expand Up @@ -153,13 +193,25 @@ impl RegistrarClientBuilder {
},
};

let resilient_client =
self.retry_config.as_ref().map(|retry_config| {
ResilientClient::new(
None,
Duration::from_millis(retry_config.initial_delay_ms),
retry_config.max_retries,
&[StatusCode::OK],
retry_config.max_delay_ms.map(Duration::from_millis),
)
});

Ok(RegistrarClient {
supported_api_versions: self
.registrar_supported_api_versions
.clone(),
api_version: registrar_api_version,
registrar_ip,
registrar_port,
resilient_client,
})
}
}
Expand Down Expand Up @@ -196,14 +248,23 @@ pub enum RegistrarClientError {
/// Reqwest error
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),

/// Serde error
#[error("Serde error: {0}")]
Serde(#[from] serde_json::Error),

/// Middleware error
#[error("Middleware error: {0}")]
Middleware(#[from] reqwest_middleware::Error),
}

#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct RegistrarClient {
api_version: String,
supported_api_versions: Option<Vec<String>>,
registrar_ip: String,
registrar_port: u32,
resilient_client: Option<ResilientClient>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -304,11 +365,21 @@ impl RegistrarClient {
&addr, &ai.uuid
);

let resp = reqwest::Client::new()
.post(&addr)
.json(&data)
.send()
.await?;
let resp = match self.resilient_client {
Some(ref client) => client
.get_json_request(reqwest::Method::POST, &addr, &data)
.map_err(RegistrarClientError::Serde)?
.send()
.await
.map_err(RegistrarClientError::Middleware)?,
None => {
reqwest::Client::new()
.post(&addr)
.json(&data)
.send()
.await?
}
};

if !resp.status().is_success() {
return Err(RegistrarClientError::Registration {
Expand Down
51 changes: 46 additions & 5 deletions keylime/src/resilient_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,13 @@ impl ResilientClient {
}
}

/// Sends a non JSON request using the client.
pub fn get_request(&self, method: Method, url: &str) -> RequestBuilder {
self.client.request(method, url)
}

/// Prepares a request with a JSON body, returning a Result.
pub fn send_json<T: Serialize>(
pub fn get_json_request<T: Serialize>(
&self,
method: Method,
url: &str,
Expand Down Expand Up @@ -170,7 +175,7 @@ mod tests {
);

let response = client
.send_json(
.get_json_request(
Method::POST,
&format!("{}/submit", &mock_server.uri()),
&json!({}),
Expand Down Expand Up @@ -205,7 +210,7 @@ mod tests {
);

let response = client
.send_json(
.get_json_request(
Method::POST,
&format!("{}/submit", &mock_server.uri()),
&json!({}),
Expand Down Expand Up @@ -241,7 +246,7 @@ mod tests {
);

let response = client
.send_json(
.get_json_request(
Method::POST,
&format!("{}/submit", &mock_server.uri()),
&json!({}),
Expand Down Expand Up @@ -287,7 +292,7 @@ mod tests {
);

let response = client
.send_json(Method::GET, &unreachable_url, &json!({}))
.get_json_request(Method::GET, &unreachable_url, &json!({}))
.unwrap() //#[allow_ci]
.send()
.await;
Expand All @@ -298,4 +303,40 @@ mod tests {
"Expected the request to fail with a network error"
);
}

#[tokio::test]
async fn test_get_request_without_body() {
let mock_server = MockServer::start().await;

Mock::given(method("GET"))
.and(path("/health"))
.respond_with(ResponseTemplate::new(200).set_body_string("OK"))
.mount(&mock_server)
.await;

let client = ResilientClient::new(
None,
Duration::from_millis(10),
3,
&[StatusCode::OK],
None,
);

let response = client
.get_request(
Method::GET,
&format!("{}/health", &mock_server.uri()),
)
.send()
.await;

assert!(response.is_ok());
let res = response.unwrap(); //#[allow_ci]
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "OK"); //#[allow_ci]

let received_requests =
mock_server.received_requests().await.unwrap(); //#[allow_ci]
assert_eq!(received_requests.len(), 1);
}
}