Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions post-compute/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ serde = "1.0.219"
serde_json = "1.0.140"
sha256 = "1.6.0"
sha3 = "0.10.8"
strum = "0.27.2"
strum_macros = "0.27.2"
thiserror = "2.0.12"
walkdir = "2.5.0"
zip = "4.0.0"
Expand Down
148 changes: 51 additions & 97 deletions post-compute/src/api/worker_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,6 @@ use crate::compute::{
};
use log::error;
use reqwest::{blocking::Client, header::AUTHORIZATION};
use serde::Serialize;

/// Represents payload that can be sent to the worker API to report the outcome of the
/// post‑compute stage.
///
/// The JSON structure expected by the REST endpoint is:
/// ```json
/// {
/// "cause": "<ReplicateStatusCause as string>"
/// }
/// ```
///
/// # Arguments
///
/// * `cause` - A reference to the ReplicateStatusCause indicating why the post-compute operation exited
///
/// # Example
///
/// ```rust
/// use tee_worker_post_compute::{
/// api::worker_api::ExitMessage,
/// compute::errors::ReplicateStatusCause,
/// };
///
/// let exit_message = ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature);
/// ```
#[derive(Serialize, Debug)]
pub struct ExitMessage<'a> {
pub cause: &'a ReplicateStatusCause,
}

impl<'a> From<&'a ReplicateStatusCause> for ExitMessage<'a> {
fn from(cause: &'a ReplicateStatusCause) -> Self {
Self { cause }
}
}

/// Thin wrapper around a [`Client`] that knows how to reach the iExec worker API.
///
Expand Down Expand Up @@ -96,21 +60,21 @@ impl WorkerApiClient {
Self::new(&base_url)
}

/// Sends an exit cause for a post-compute operation to the Worker API.
/// Sends exit causes for a post-compute operation to the Worker API.
///
/// This method reports the exit cause of a post-compute operation to the Worker API,
/// This method reports the exit causes of a post-compute operation to the Worker API,
/// which can be used for tracking and debugging purposes.
///
/// # Arguments
///
/// * `authorization` - The authorization token to use for the API request
/// * `chain_task_id` - The chain task ID for which to report the exit cause
/// * `exit_cause` - The exit cause to report
/// * `chain_task_id` - The chain task ID for which to report the exit causes
/// * `exit_causes` - The exit causes to report
///
/// # Returns
///
/// * `Ok(())` - If the exit cause was successfully reported
/// * `Err(ReplicateStatusCause)` - If the exit cause could not be reported due to an HTTP error
/// * `Ok(())` - If the exit causes were successfully reported
/// * `Err(ReplicateStatusCause)` - If the exit causes could not be reported due to an HTTP error
///
/// # Errors
///
Expand All @@ -121,34 +85,34 @@ impl WorkerApiClient {
///
/// ```rust
/// use tee_worker_post_compute::{
/// api::worker_api::{ExitMessage, WorkerApiClient},
/// api::worker_api::WorkerApiClient,
/// compute::errors::ReplicateStatusCause,
/// };
///
/// let client = WorkerApiClient::new("http://worker:13100");
/// let exit_message = ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature);
/// let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature];
///
/// match client.send_exit_cause_for_post_compute_stage(
/// match client.send_exit_causes_for_post_compute_stage(
/// "authorization_token",
/// "0x123456789abcdef",
/// &exit_message,
/// &exit_causes,
/// ) {
/// Ok(()) => println!("Exit cause reported successfully"),
/// Err(error) => eprintln!("Failed to report exit cause: {}", error),
/// Ok(()) => println!("Exit causes reported successfully"),
/// Err(error) => eprintln!("Failed to report exit causes: {}", error),
/// }
/// ```
pub fn send_exit_cause_for_post_compute_stage(
pub fn send_exit_causes_for_post_compute_stage(
&self,
authorization: &str,
chain_task_id: &str,
exit_cause: &ExitMessage,
exit_causes: &[ReplicateStatusCause],
) -> Result<(), ReplicateStatusCause> {
let url = format!("{}/compute/post/{chain_task_id}/exit", self.base_url);
let url = format!("{}/compute/post/{chain_task_id}/exit-causes", self.base_url);
match self
.client
.post(&url)
.header(AUTHORIZATION, authorization)
.json(exit_cause)
.json(exit_causes)
.send()
{
Ok(response) => {
Expand All @@ -158,13 +122,13 @@ impl WorkerApiClient {
let status = response.status();
let body = response.text().unwrap_or_default();
error!(
"Failed to send exit cause to worker: [status:{status:?}, body:{body:#?}]"
"Failed to send exit causes to worker: [status:{status:?}, body:{body:#?}]"
);
Err(ReplicateStatusCause::PostComputeFailedUnknownIssue)
}
}
Err(e) => {
error!("An error occured while sending exit cause to worker: {e}");
error!("An error occured while sending exit causes to worker: {e}");
Err(ReplicateStatusCause::PostComputeFailedUnknownIssue)
}
}
Expand Down Expand Up @@ -266,36 +230,30 @@ mod tests {
const CHALLENGE: &str = "challenge";
const CHAIN_TASK_ID: &str = "0x123456789abcdef";

// region ExitMessage()
// region serialize List of ReplicateStatusCause
#[test]
fn should_serialize_exit_message() {
let causes = [
(
ReplicateStatusCause::PostComputeInvalidTeeSignature,
"POST_COMPUTE_INVALID_TEE_SIGNATURE",
),
(
ReplicateStatusCause::PostComputeWorkerAddressMissing,
"POST_COMPUTE_WORKER_ADDRESS_MISSING",
),
(
ReplicateStatusCause::PostComputeFailedUnknownIssue,
"POST_COMPUTE_FAILED_UNKNOWN_ISSUE",
),
fn replicate_status_cause_serializes_as_json_array_when_multiple_causes() {
let causes = vec![
ReplicateStatusCause::PostComputeInvalidTeeSignature,
ReplicateStatusCause::PostComputeWorkerAddressMissing,
];
let serialized = to_string(&causes).expect("Failed to serialize");
let expected = r#"[{"cause":"POST_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"},{"cause":"POST_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address not found in TEE session"}]"#;
assert_eq!(serialized, expected);
}

for (cause, message) in causes {
let exit_message = ExitMessage::from(&cause);
let serialized = to_string(&exit_message).expect("Failed to serialize");
let expected = format!("{{\"cause\":\"{message}\"}}");
assert_eq!(serialized, expected);
}
#[test]
fn replicate_status_cause_serializes_as_json_array_when_single_cause() {
let causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue];
let serialized = to_string(&causes).expect("Failed to serialize");
let expected = r#"[{"cause":"POST_COMPUTE_FAILED_UNKNOWN_ISSUE","message":"Unexpected error occurred"}]"#;
assert_eq!(serialized, expected);
}
// endregion

// region get_worker_api_client
#[test]
fn should_get_worker_api_client_with_env_var() {
fn from_env_creates_client_with_custom_url_when_env_var_set() {
with_vars(
vec![(WorkerHostEnvVar.name(), Some("custom-worker-host:9999"))],
|| {
Expand All @@ -306,26 +264,24 @@ mod tests {
}

#[test]
fn should_get_worker_api_client_without_env_var() {
fn from_env_creates_client_with_default_url_when_env_var_missing() {
with_vars(vec![(WorkerHostEnvVar.name(), None::<&str>)], || {
let client = WorkerApiClient::from_env();
assert_eq!(client.base_url, format!("http://{DEFAULT_WORKER_HOST}"));
});
}
// endregion

// region send_exit_cause_for_post_compute_stage()
// region send_exit_causes_for_post_compute_stage()
#[tokio::test]
async fn should_send_exit_cause() {
async fn send_exit_causes_for_post_compute_stage_succeeds_when_server_responds_ok() {
let mock_server = MockServer::start().await;
let server_url = mock_server.uri();

let expected_body = json!({
"cause": ReplicateStatusCause::PostComputeInvalidTeeSignature,
});
let expected_body = json!([ReplicateStatusCause::PostComputeInvalidTeeSignature,]);

Mock::given(method("POST"))
.and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit")))
.and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit-causes")))
.and(header("Authorization", CHALLENGE))
.and(body_json(&expected_body))
.respond_with(ResponseTemplate::new(200))
Expand All @@ -334,13 +290,12 @@ mod tests {
.await;

let result = tokio::task::spawn_blocking(move || {
let exit_message =
ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature);
let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature];
let worker_api_client = WorkerApiClient::new(&server_url);
worker_api_client.send_exit_cause_for_post_compute_stage(
worker_api_client.send_exit_causes_for_post_compute_stage(
CHALLENGE,
CHAIN_TASK_ID,
&exit_message,
&exit_causes,
)
})
.await
Expand All @@ -351,7 +306,7 @@ mod tests {

#[tokio::test]
#[serial]
async fn should_not_send_exit_cause() {
async fn send_exit_causes_for_post_compute_stage_fails_when_server_returns_404() {
{
let mut logger = TEST_LOGGER.lock().unwrap();
while logger.pop().is_some() {}
Expand All @@ -360,20 +315,19 @@ mod tests {
let server_url = mock_server.uri();

Mock::given(method("POST"))
.and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit")))
.and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit-causes")))
.respond_with(ResponseTemplate::new(404))
.expect(1)
.mount(&mock_server)
.await;

let result = tokio::task::spawn_blocking(move || {
let exit_message =
ExitMessage::from(&ReplicateStatusCause::PostComputeFailedUnknownIssue);
let exit_causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue];
let worker_api_client = WorkerApiClient::new(&server_url);
worker_api_client.send_exit_cause_for_post_compute_stage(
worker_api_client.send_exit_causes_for_post_compute_stage(
CHALLENGE,
CHAIN_TASK_ID,
&exit_message,
&exit_causes,
)
})
.await
Expand Down Expand Up @@ -402,7 +356,7 @@ mod tests {

// region send_computed_file_to_host()
#[tokio::test]
async fn should_send_computed_file_successfully() {
async fn send_computed_file_to_host_succeeds_when_server_responds_ok() {
let mock_server = MockServer::start().await;
let server_uri = mock_server.uri();

Expand Down Expand Up @@ -437,7 +391,7 @@ mod tests {

#[tokio::test]
#[serial]
async fn should_fail_send_computed_file_on_server_error() {
async fn send_computed_file_to_host_fails_when_server_returns_500() {
{
let mut logger = TEST_LOGGER.lock().unwrap();
while logger.pop().is_some() {}
Expand Down Expand Up @@ -491,7 +445,7 @@ mod tests {

#[tokio::test]
#[serial]
async fn should_handle_invalid_chain_task_id_in_url() {
async fn send_computed_file_to_host_fails_when_chain_task_id_invalid() {
{
let mut logger = TEST_LOGGER.lock().unwrap();
while logger.pop().is_some() {}
Expand Down Expand Up @@ -532,7 +486,7 @@ mod tests {
}

#[tokio::test]
async fn should_send_computed_file_with_minimal_data() {
async fn send_computed_file_to_host_succeeds_when_minimal_data_provided() {
let mock_server = MockServer::start().await;
let server_uri = mock_server.uri();

Expand Down
Loading