diff --git a/Cargo.lock b/Cargo.lock index ad17538..56fe159 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4166,6 +4166,8 @@ dependencies = [ "serial_test", "sha256", "sha3", + "strum", + "strum_macros", "temp-env", "tempfile", "thiserror", diff --git a/post-compute/Cargo.toml b/post-compute/Cargo.toml index b5a4c23..4d1bf7e 100644 --- a/post-compute/Cargo.toml +++ b/post-compute/Cargo.toml @@ -22,6 +22,8 @@ serde = "1.0.219" serde_json = "1.0.140" sha256 = "1.6.0" sha3 = "0.10.8" +strum = "0.27.2" +strum_macros = "0.27.2" thiserror = "2.0.12" walkdir = "2.5.0" zip = "4.0.0" diff --git a/post-compute/src/api/worker_api.rs b/post-compute/src/api/worker_api.rs index ff3943d..265bd03 100644 --- a/post-compute/src/api/worker_api.rs +++ b/post-compute/src/api/worker_api.rs @@ -5,42 +5,6 @@ use crate::compute::{ }; use log::error; use reqwest::{blocking::Client, header::AUTHORIZATION}; -use serde::Serialize; - -/// Represents payload that can be sent to the worker API to report the outcome of the -/// post‑compute stage. -/// -/// The JSON structure expected by the REST endpoint is: -/// ```json -/// { -/// "cause": "" -/// } -/// ``` -/// -/// # Arguments -/// -/// * `cause` - A reference to the ReplicateStatusCause indicating why the post-compute operation exited -/// -/// # Example -/// -/// ```rust -/// use tee_worker_post_compute::{ -/// api::worker_api::ExitMessage, -/// compute::errors::ReplicateStatusCause, -/// }; -/// -/// let exit_message = ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature); -/// ``` -#[derive(Serialize, Debug)] -pub struct ExitMessage<'a> { - pub cause: &'a ReplicateStatusCause, -} - -impl<'a> From<&'a ReplicateStatusCause> for ExitMessage<'a> { - fn from(cause: &'a ReplicateStatusCause) -> Self { - Self { cause } - } -} /// Thin wrapper around a [`Client`] that knows how to reach the iExec worker API. /// @@ -96,21 +60,21 @@ impl WorkerApiClient { Self::new(&base_url) } - /// Sends an exit cause for a post-compute operation to the Worker API. + /// Sends exit causes for a post-compute operation to the Worker API. /// - /// This method reports the exit cause of a post-compute operation to the Worker API, + /// This method reports the exit causes of a post-compute operation to the Worker API, /// which can be used for tracking and debugging purposes. /// /// # Arguments /// /// * `authorization` - The authorization token to use for the API request - /// * `chain_task_id` - The chain task ID for which to report the exit cause - /// * `exit_cause` - The exit cause to report + /// * `chain_task_id` - The chain task ID for which to report the exit causes + /// * `exit_causes` - The exit causes to report /// /// # Returns /// - /// * `Ok(())` - If the exit cause was successfully reported - /// * `Err(ReplicateStatusCause)` - If the exit cause could not be reported due to an HTTP error + /// * `Ok(())` - If the exit causes were successfully reported + /// * `Err(ReplicateStatusCause)` - If the exit causes could not be reported due to an HTTP error /// /// # Errors /// @@ -121,34 +85,34 @@ impl WorkerApiClient { /// /// ```rust /// use tee_worker_post_compute::{ - /// api::worker_api::{ExitMessage, WorkerApiClient}, + /// api::worker_api::WorkerApiClient, /// compute::errors::ReplicateStatusCause, /// }; /// /// let client = WorkerApiClient::new("http://worker:13100"); - /// let exit_message = ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature); + /// let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature]; /// - /// match client.send_exit_cause_for_post_compute_stage( + /// match client.send_exit_causes_for_post_compute_stage( /// "authorization_token", /// "0x123456789abcdef", - /// &exit_message, + /// &exit_causes, /// ) { - /// Ok(()) => println!("Exit cause reported successfully"), - /// Err(error) => eprintln!("Failed to report exit cause: {}", error), + /// Ok(()) => println!("Exit causes reported successfully"), + /// Err(error) => eprintln!("Failed to report exit causes: {}", error), /// } /// ``` - pub fn send_exit_cause_for_post_compute_stage( + pub fn send_exit_causes_for_post_compute_stage( &self, authorization: &str, chain_task_id: &str, - exit_cause: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { - let url = format!("{}/compute/post/{chain_task_id}/exit", self.base_url); + let url = format!("{}/compute/post/{chain_task_id}/exit-causes", self.base_url); match self .client .post(&url) .header(AUTHORIZATION, authorization) - .json(exit_cause) + .json(exit_causes) .send() { Ok(response) => { @@ -158,13 +122,13 @@ impl WorkerApiClient { let status = response.status(); let body = response.text().unwrap_or_default(); error!( - "Failed to send exit cause to worker: [status:{status:?}, body:{body:#?}]" + "Failed to send exit causes to worker: [status:{status:?}, body:{body:#?}]" ); Err(ReplicateStatusCause::PostComputeFailedUnknownIssue) } } Err(e) => { - error!("An error occured while sending exit cause to worker: {e}"); + error!("An error occured while sending exit causes to worker: {e}"); Err(ReplicateStatusCause::PostComputeFailedUnknownIssue) } } @@ -266,36 +230,30 @@ mod tests { const CHALLENGE: &str = "challenge"; const CHAIN_TASK_ID: &str = "0x123456789abcdef"; - // region ExitMessage() + // region serialize List of ReplicateStatusCause #[test] - fn should_serialize_exit_message() { - let causes = [ - ( - ReplicateStatusCause::PostComputeInvalidTeeSignature, - "POST_COMPUTE_INVALID_TEE_SIGNATURE", - ), - ( - ReplicateStatusCause::PostComputeWorkerAddressMissing, - "POST_COMPUTE_WORKER_ADDRESS_MISSING", - ), - ( - ReplicateStatusCause::PostComputeFailedUnknownIssue, - "POST_COMPUTE_FAILED_UNKNOWN_ISSUE", - ), + fn replicate_status_cause_serializes_as_json_array_when_multiple_causes() { + let causes = vec![ + ReplicateStatusCause::PostComputeInvalidTeeSignature, + ReplicateStatusCause::PostComputeWorkerAddressMissing, ]; + let serialized = to_string(&causes).expect("Failed to serialize"); + let expected = r#"[{"cause":"POST_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"},{"cause":"POST_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address not found in TEE session"}]"#; + assert_eq!(serialized, expected); + } - for (cause, message) in causes { - let exit_message = ExitMessage::from(&cause); - let serialized = to_string(&exit_message).expect("Failed to serialize"); - let expected = format!("{{\"cause\":\"{message}\"}}"); - assert_eq!(serialized, expected); - } + #[test] + fn replicate_status_cause_serializes_as_json_array_when_single_cause() { + let causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue]; + let serialized = to_string(&causes).expect("Failed to serialize"); + let expected = r#"[{"cause":"POST_COMPUTE_FAILED_UNKNOWN_ISSUE","message":"Unexpected error occurred"}]"#; + assert_eq!(serialized, expected); } // endregion // region get_worker_api_client #[test] - fn should_get_worker_api_client_with_env_var() { + fn from_env_creates_client_with_custom_url_when_env_var_set() { with_vars( vec![(WorkerHostEnvVar.name(), Some("custom-worker-host:9999"))], || { @@ -306,7 +264,7 @@ mod tests { } #[test] - fn should_get_worker_api_client_without_env_var() { + fn from_env_creates_client_with_default_url_when_env_var_missing() { with_vars(vec![(WorkerHostEnvVar.name(), None::<&str>)], || { let client = WorkerApiClient::from_env(); assert_eq!(client.base_url, format!("http://{DEFAULT_WORKER_HOST}")); @@ -314,18 +272,16 @@ mod tests { } // endregion - // region send_exit_cause_for_post_compute_stage() + // region send_exit_causes_for_post_compute_stage() #[tokio::test] - async fn should_send_exit_cause() { + async fn send_exit_causes_for_post_compute_stage_succeeds_when_server_responds_ok() { let mock_server = MockServer::start().await; let server_url = mock_server.uri(); - let expected_body = json!({ - "cause": ReplicateStatusCause::PostComputeInvalidTeeSignature, - }); + let expected_body = json!([ReplicateStatusCause::PostComputeInvalidTeeSignature,]); Mock::given(method("POST")) - .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit-causes"))) .and(header("Authorization", CHALLENGE)) .and(body_json(&expected_body)) .respond_with(ResponseTemplate::new(200)) @@ -334,13 +290,12 @@ mod tests { .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature); + let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature]; let worker_api_client = WorkerApiClient::new(&server_url); - worker_api_client.send_exit_cause_for_post_compute_stage( + worker_api_client.send_exit_causes_for_post_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ) }) .await @@ -351,7 +306,7 @@ mod tests { #[tokio::test] #[serial] - async fn should_not_send_exit_cause() { + async fn send_exit_causes_for_post_compute_stage_fails_when_server_returns_404() { { let mut logger = TEST_LOGGER.lock().unwrap(); while logger.pop().is_some() {} @@ -360,20 +315,19 @@ mod tests { let server_url = mock_server.uri(); Mock::given(method("POST")) - .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit-causes"))) .respond_with(ResponseTemplate::new(404)) .expect(1) .mount(&mock_server) .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PostComputeFailedUnknownIssue); + let exit_causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new(&server_url); - worker_api_client.send_exit_cause_for_post_compute_stage( + worker_api_client.send_exit_causes_for_post_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ) }) .await @@ -402,7 +356,7 @@ mod tests { // region send_computed_file_to_host() #[tokio::test] - async fn should_send_computed_file_successfully() { + async fn send_computed_file_to_host_succeeds_when_server_responds_ok() { let mock_server = MockServer::start().await; let server_uri = mock_server.uri(); @@ -437,7 +391,7 @@ mod tests { #[tokio::test] #[serial] - async fn should_fail_send_computed_file_on_server_error() { + async fn send_computed_file_to_host_fails_when_server_returns_500() { { let mut logger = TEST_LOGGER.lock().unwrap(); while logger.pop().is_some() {} @@ -491,7 +445,7 @@ mod tests { #[tokio::test] #[serial] - async fn should_handle_invalid_chain_task_id_in_url() { + async fn send_computed_file_to_host_fails_when_chain_task_id_invalid() { { let mut logger = TEST_LOGGER.lock().unwrap(); while logger.pop().is_some() {} @@ -532,7 +486,7 @@ mod tests { } #[tokio::test] - async fn should_send_computed_file_with_minimal_data() { + async fn send_computed_file_to_host_succeeds_when_minimal_data_provided() { let mock_server = MockServer::start().await; let server_uri = mock_server.uri(); diff --git a/post-compute/src/compute/app_runner.rs b/post-compute/src/compute/app_runner.rs index da70cec..7199a0e 100644 --- a/post-compute/src/compute/app_runner.rs +++ b/post-compute/src/compute/app_runner.rs @@ -1,4 +1,4 @@ -use crate::api::worker_api::{ExitMessage, WorkerApiClient}; +use crate::api::worker_api::WorkerApiClient; use crate::compute::{ computed_file::{ ComputedFile, build_result_digest_in_computed_file, read_computed_file, sign_computed_file, @@ -31,11 +31,11 @@ pub enum ExitMode { pub trait PostComputeRunnerInterface { fn run_post_compute(&self, chain_task_id: &str) -> Result<(), ReplicateStatusCause>; fn get_challenge(&self, chain_task_id: &str) -> Result; - fn send_exit_cause( + fn send_exit_causes( &self, authorization: &str, chain_task_id: &str, - exit_message: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause>; fn send_computed_file(&self, computed_file: &ComputedFile) -> Result<(), ReplicateStatusCause>; } @@ -99,14 +99,14 @@ impl PostComputeRunnerInterface for DefaultPostComputeRunner { get_challenge(chain_task_id) } - fn send_exit_cause( + fn send_exit_causes( &self, authorization: &str, chain_task_id: &str, - exit_message: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { self.worker_api_client - .send_exit_cause_for_post_compute_stage(authorization, chain_task_id, exit_message) + .send_exit_causes_for_post_compute_stage(authorization, chain_task_id, exit_causes) } fn send_computed_file(&self, computed_file: &ComputedFile) -> Result<(), ReplicateStatusCause> { @@ -189,12 +189,12 @@ pub fn start_with_runner(runner: &R) -> ExitMode } }; - let exit_message = ExitMessage::from(&exit_cause); + let exit_causes = vec![exit_cause.clone()]; - match runner.send_exit_cause(&authorization, &chain_task_id, &exit_message) { + match runner.send_exit_causes(&authorization, &chain_task_id, &exit_causes) { Ok(()) => ExitMode::ReportedFailure, Err(_) => { - error!("Failed to report exit cause [exitCause:{exit_cause}]"); + error!("Failed to report exit causes [exitCauses:{exit_causes:?}]"); ExitMode::UnreportedFailure } } @@ -273,7 +273,7 @@ mod tests { self } - fn with_send_exit_cause_failure(mut self) -> Self { + fn with_send_exit_causes_failure(mut self) -> Self { self.send_exit_cause_success = false; self } @@ -298,11 +298,11 @@ mod tests { } } - fn send_exit_cause( + fn send_exit_causes( &self, _authorization: &str, _chain_task_id: &str, - _exit_message: &ExitMessage, + _exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { if self.send_exit_cause_success { Ok(()) @@ -423,7 +423,7 @@ mod tests { } #[test] - fn start_return_2_when_exit_cause_not_transmitted() { + fn start_return_2_when_exit_causes_not_transmitted() { with_vars( vec![( TeeSessionEnvironmentVariable::IexecTaskId.name(), @@ -434,7 +434,7 @@ mod tests { .with_run_post_compute_failure(Some( ReplicateStatusCause::PostComputeInvalidTeeSignature, )) - .with_send_exit_cause_failure(); + .with_send_exit_causes_failure(); let result = start_with_runner(&runner); assert_eq!( diff --git a/post-compute/src/compute/errors.rs b/post-compute/src/compute/errors.rs index b6bfcdd..8c6c228 100644 --- a/post-compute/src/compute/errors.rs +++ b/post-compute/src/compute/errors.rs @@ -1,9 +1,10 @@ -use serde::{Deserialize, Serialize}; +use serde::{Serializer, ser::SerializeStruct}; +use strum_macros::EnumDiscriminants; use thiserror::Error; -#[derive(Debug, PartialEq, Clone, Error, Serialize, Deserialize)] -#[serde(rename_all(serialize = "SCREAMING_SNAKE_CASE"))] -#[allow(clippy::enum_variant_names)] +#[derive(Debug, PartialEq, Clone, Error, EnumDiscriminants)] +#[strum_discriminants(derive(serde::Serialize))] +#[strum_discriminants(serde(rename_all = "SCREAMING_SNAKE_CASE"))] pub enum ReplicateStatusCause { #[error("computed.json file missing")] PostComputeComputedFileNotFound, @@ -11,7 +12,7 @@ pub enum ReplicateStatusCause { PostComputeDropboxUploadFailed, #[error("Encryption stage failed")] PostComputeEncryptionFailed, - #[error("Encryption public key related environment variable is missing")] + #[error("Encryption public key not found in TEE session")] PostComputeEncryptionPublicKeyMissing, #[error("Unexpected error occurred")] PostComputeFailedUnknownIssue, @@ -29,14 +30,74 @@ pub enum ReplicateStatusCause { PostComputeResultFileNotFound, #[error("Failed to send computed file")] PostComputeSendComputedFileFailed, - #[error("Storage token related environment variable is missing")] + #[error("Storage token not found in TEE session")] PostComputeStorageTokenMissing, - #[error("Task ID related environment variable is missing")] + #[error("Task ID not found in TEE session")] PostComputeTaskIdMissing, - #[error("Tee challenge private key related environment variable is missing")] + #[error("TEE challenge private key not found in TEE session")] PostComputeTeeChallengePrivateKeyMissing, #[error("Result file name too long")] PostComputeTooLongResultFileName, - #[error("Worker address related environment variable is missing")] + #[error("Worker address not found in TEE session")] PostComputeWorkerAddressMissing, } + +impl serde::Serialize for ReplicateStatusCause { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("ReplicateStatusCause", 2)?; + state.serialize_field("cause", &ReplicateStatusCauseDiscriminants::from(self))?; + state.serialize_field("message", &self.to_string())?; + state.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::{json, to_value}; + + #[test] + fn error_variant_serialize_correctly() { + let expected = json!({ + "cause": "POST_COMPUTE_TOO_LONG_RESULT_FILE_NAME", + "message": "Result file name too long" + }); + let error_variant = ReplicateStatusCause::PostComputeTooLongResultFileName; + assert_eq!(to_value(&error_variant).unwrap(), expected); + } + + #[test] + fn error_list_serializes_as_json_array() { + let errors = vec![ + ReplicateStatusCause::PostComputeComputedFileNotFound, + ReplicateStatusCause::PostComputeInvalidTeeSignature, + ReplicateStatusCause::PostComputeTaskIdMissing, + ]; + let serialized = to_value(&errors).unwrap(); + let expected = json!([ + { + "cause": "POST_COMPUTE_COMPUTED_FILE_NOT_FOUND", + "message": "computed.json file missing" + }, + { + "cause": "POST_COMPUTE_INVALID_TEE_SIGNATURE", + "message": "Invalid TEE signature" + }, + { + "cause": "POST_COMPUTE_TASK_ID_MISSING", + "message": "Task ID not found in TEE session" + } + ]); + assert_eq!(serialized, expected); + } + + #[test] + fn empty_error_list_serializes_as_empty_json_array() { + let errors: Vec = vec![]; + let serialized = to_value(&errors).unwrap(); + assert_eq!(serialized, json!([])); + } +} diff --git a/post-compute/src/compute/utils/hash_utils.rs b/post-compute/src/compute/utils/hash_utils.rs index 6916772..54d8f01 100644 --- a/post-compute/src/compute/utils/hash_utils.rs +++ b/post-compute/src/compute/utils/hash_utils.rs @@ -17,7 +17,7 @@ pub fn hex_string_to_byte_array(input: &str) -> Vec { } let mut data: Vec = vec![]; - let start_idx = if len % 2 != 0 { + let start_idx = if !len.is_multiple_of(2) { let byte = u8::from_str_radix(&clean_input[0..1], 16).expect("Invalid hex digit in input string"); data.push(byte);