From d8048cac45101b6de2995e34538680389f307b3f Mon Sep 17 00:00:00 2001 From: Natchica Date: Thu, 16 Oct 2025 15:44:24 +0200 Subject: [PATCH 1/6] feat: return list of errors in new format to worker --- Cargo.lock | 44 ++++- post-compute/Cargo.toml | 1 + post-compute/src/api/worker_api.rs | 96 +++------- post-compute/src/compute/app_runner.rs | 14 +- post-compute/src/compute/errors.rs | 185 ++++++++++++++++++- post-compute/src/compute/utils/hash_utils.rs | 2 +- 6 files changed, 260 insertions(+), 82 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3f7c26..9d6f5b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1700,6 +1700,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -3331,6 +3337,12 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "reqwest" version = "0.12.23" @@ -3427,6 +3439,35 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rstest" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49" +dependencies = [ + "futures-timer", + "futures-util", + "rstest_macros", +] + +[[package]] +name = "rstest_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version 0.4.1", + "syn 2.0.106", + "unicode-ident", +] + [[package]] name = "ruint" version = "1.16.0" @@ -4143,6 +4184,7 @@ dependencies = [ "rand 0.8.5", "reqwest", "rsa", + "rstest", "serde", "serde_json", "serial_test", @@ -4159,7 +4201,7 @@ dependencies = [ [[package]] name = "tee-worker-pre-compute" -version = "0.2.0" +version = "0.3.0" dependencies = [ "aes", "alloy-signer", diff --git a/post-compute/Cargo.toml b/post-compute/Cargo.toml index b5a4c23..f98031e 100644 --- a/post-compute/Cargo.toml +++ b/post-compute/Cargo.toml @@ -30,6 +30,7 @@ zip = "4.0.0" logtest = "2.0.0" mockall = "0.13.1" once_cell = "1.21.3" +rstest = "0.26.1" serial_test = "3.2.0" temp-env = "0.3.6" tempfile = "3.20.0" diff --git a/post-compute/src/api/worker_api.rs b/post-compute/src/api/worker_api.rs index ff3943d..488fa49 100644 --- a/post-compute/src/api/worker_api.rs +++ b/post-compute/src/api/worker_api.rs @@ -5,42 +5,6 @@ use crate::compute::{ }; use log::error; use reqwest::{blocking::Client, header::AUTHORIZATION}; -use serde::Serialize; - -/// Represents payload that can be sent to the worker API to report the outcome of the -/// post‑compute stage. -/// -/// The JSON structure expected by the REST endpoint is: -/// ```json -/// { -/// "cause": "" -/// } -/// ``` -/// -/// # Arguments -/// -/// * `cause` - A reference to the ReplicateStatusCause indicating why the post-compute operation exited -/// -/// # Example -/// -/// ```rust -/// use tee_worker_post_compute::{ -/// api::worker_api::ExitMessage, -/// compute::errors::ReplicateStatusCause, -/// }; -/// -/// let exit_message = ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature); -/// ``` -#[derive(Serialize, Debug)] -pub struct ExitMessage<'a> { - pub cause: &'a ReplicateStatusCause, -} - -impl<'a> From<&'a ReplicateStatusCause> for ExitMessage<'a> { - fn from(cause: &'a ReplicateStatusCause) -> Self { - Self { cause } - } -} /// Thin wrapper around a [`Client`] that knows how to reach the iExec worker API. /// @@ -121,17 +85,17 @@ impl WorkerApiClient { /// /// ```rust /// use tee_worker_post_compute::{ - /// api::worker_api::{ExitMessage, WorkerApiClient}, + /// api::worker_api::WorkerApiClient, /// compute::errors::ReplicateStatusCause, /// }; /// /// let client = WorkerApiClient::new("http://worker:13100"); - /// let exit_message = ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature); + /// let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature]; /// /// match client.send_exit_cause_for_post_compute_stage( /// "authorization_token", /// "0x123456789abcdef", - /// &exit_message, + /// &exit_causes, /// ) { /// Ok(()) => println!("Exit cause reported successfully"), /// Err(error) => eprintln!("Failed to report exit cause: {}", error), @@ -141,14 +105,14 @@ impl WorkerApiClient { &self, authorization: &str, chain_task_id: &str, - exit_cause: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { let url = format!("{}/compute/post/{chain_task_id}/exit", self.base_url); match self .client .post(&url) .header(AUTHORIZATION, authorization) - .json(exit_cause) + .json(exit_causes) .send() { Ok(response) => { @@ -266,30 +230,24 @@ mod tests { const CHALLENGE: &str = "challenge"; const CHAIN_TASK_ID: &str = "0x123456789abcdef"; - // region ExitMessage() + // region serialize List of ReplicateStatusCause #[test] - fn should_serialize_exit_message() { - let causes = [ - ( - ReplicateStatusCause::PostComputeInvalidTeeSignature, - "POST_COMPUTE_INVALID_TEE_SIGNATURE", - ), - ( - ReplicateStatusCause::PostComputeWorkerAddressMissing, - "POST_COMPUTE_WORKER_ADDRESS_MISSING", - ), - ( - ReplicateStatusCause::PostComputeFailedUnknownIssue, - "POST_COMPUTE_FAILED_UNKNOWN_ISSUE", - ), + fn should_serialize_list_of_exit_causes() { + let causes = vec![ + ReplicateStatusCause::PostComputeInvalidTeeSignature, + ReplicateStatusCause::PostComputeWorkerAddressMissing, ]; + let serialized = to_string(&causes).expect("Failed to serialize"); + let expected = r#"[{"cause":"POST_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"},{"cause":"POST_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address related environment variable is missing"}]"#; + assert_eq!(serialized, expected); + } - for (cause, message) in causes { - let exit_message = ExitMessage::from(&cause); - let serialized = to_string(&exit_message).expect("Failed to serialize"); - let expected = format!("{{\"cause\":\"{message}\"}}"); - assert_eq!(serialized, expected); - } + #[test] + fn should_serialize_single_exit_cause() { + let causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue]; + let serialized = to_string(&causes).expect("Failed to serialize"); + let expected = r#"[{"cause":"POST_COMPUTE_FAILED_UNKNOWN_ISSUE","message":"Unexpected error occurred"}]"#; + assert_eq!(serialized, expected); } // endregion @@ -320,9 +278,7 @@ mod tests { let mock_server = MockServer::start().await; let server_url = mock_server.uri(); - let expected_body = json!({ - "cause": ReplicateStatusCause::PostComputeInvalidTeeSignature, - }); + let expected_body = json!([ReplicateStatusCause::PostComputeInvalidTeeSignature,]); Mock::given(method("POST")) .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit"))) @@ -334,13 +290,12 @@ mod tests { .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PostComputeInvalidTeeSignature); + let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature]; let worker_api_client = WorkerApiClient::new(&server_url); worker_api_client.send_exit_cause_for_post_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ) }) .await @@ -367,13 +322,12 @@ mod tests { .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PostComputeFailedUnknownIssue); + let exit_causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new(&server_url); worker_api_client.send_exit_cause_for_post_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ) }) .await diff --git a/post-compute/src/compute/app_runner.rs b/post-compute/src/compute/app_runner.rs index da70cec..707efc0 100644 --- a/post-compute/src/compute/app_runner.rs +++ b/post-compute/src/compute/app_runner.rs @@ -1,4 +1,4 @@ -use crate::api::worker_api::{ExitMessage, WorkerApiClient}; +use crate::api::worker_api::WorkerApiClient; use crate::compute::{ computed_file::{ ComputedFile, build_result_digest_in_computed_file, read_computed_file, sign_computed_file, @@ -35,7 +35,7 @@ pub trait PostComputeRunnerInterface { &self, authorization: &str, chain_task_id: &str, - exit_message: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause>; fn send_computed_file(&self, computed_file: &ComputedFile) -> Result<(), ReplicateStatusCause>; } @@ -103,10 +103,10 @@ impl PostComputeRunnerInterface for DefaultPostComputeRunner { &self, authorization: &str, chain_task_id: &str, - exit_message: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { self.worker_api_client - .send_exit_cause_for_post_compute_stage(authorization, chain_task_id, exit_message) + .send_exit_cause_for_post_compute_stage(authorization, chain_task_id, exit_causes) } fn send_computed_file(&self, computed_file: &ComputedFile) -> Result<(), ReplicateStatusCause> { @@ -189,9 +189,9 @@ pub fn start_with_runner(runner: &R) -> ExitMode } }; - let exit_message = ExitMessage::from(&exit_cause); + let exit_causes = vec![exit_cause.clone()]; - match runner.send_exit_cause(&authorization, &chain_task_id, &exit_message) { + match runner.send_exit_cause(&authorization, &chain_task_id, &exit_causes) { Ok(()) => ExitMode::ReportedFailure, Err(_) => { error!("Failed to report exit cause [exitCause:{exit_cause}]"); @@ -302,7 +302,7 @@ mod tests { &self, _authorization: &str, _chain_task_id: &str, - _exit_message: &ExitMessage, + _exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { if self.send_exit_cause_success { Ok(()) diff --git a/post-compute/src/compute/errors.rs b/post-compute/src/compute/errors.rs index b6bfcdd..b3b8f7b 100644 --- a/post-compute/src/compute/errors.rs +++ b/post-compute/src/compute/errors.rs @@ -1,8 +1,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; -#[derive(Debug, PartialEq, Clone, Error, Serialize, Deserialize)] -#[serde(rename_all(serialize = "SCREAMING_SNAKE_CASE"))] +#[derive(Debug, PartialEq, Clone, Error, Deserialize)] #[allow(clippy::enum_variant_names)] pub enum ReplicateStatusCause { #[error("computed.json file missing")] @@ -40,3 +39,185 @@ pub enum ReplicateStatusCause { #[error("Worker address related environment variable is missing")] PostComputeWorkerAddressMissing, } + +impl ReplicateStatusCause { + fn to_screaming_snake_case(&self) -> String { + let debug_str = format!("{:?}", self); + let mut result = String::new(); + let mut prev_was_lowercase = false; + + for c in debug_str.chars() { + if c.is_uppercase() && !result.is_empty() && prev_was_lowercase { + result.push('_'); + } + result.push(c.to_ascii_uppercase()); + prev_was_lowercase = c.is_lowercase(); + } + + result + } +} + +#[derive(Debug, Serialize)] +pub struct WorkflowError { + pub cause: String, + pub message: String, +} + +impl From<&ReplicateStatusCause> for WorkflowError { + fn from(cause: &ReplicateStatusCause) -> Self { + WorkflowError { + cause: cause.to_screaming_snake_case(), + message: cause.to_string(), + } + } +} + +impl Serialize for ReplicateStatusCause { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + WorkflowError::from(self).serialize(serializer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + use serde_json::{json, to_value}; + + #[rstest] + #[case( + ReplicateStatusCause::PostComputeComputedFileNotFound, + "POST_COMPUTE_COMPUTED_FILE_NOT_FOUND", + "computed.json file missing" + )] + #[case( + ReplicateStatusCause::PostComputeDropboxUploadFailed, + "POST_COMPUTE_DROPBOX_UPLOAD_FAILED", + "Failed to upload to Dropbox" + )] + #[case( + ReplicateStatusCause::PostComputeEncryptionFailed, + "POST_COMPUTE_ENCRYPTION_FAILED", + "Encryption stage failed" + )] + #[case( + ReplicateStatusCause::PostComputeEncryptionPublicKeyMissing, + "POST_COMPUTE_ENCRYPTION_PUBLIC_KEY_MISSING", + "Encryption public key related environment variable is missing" + )] + #[case( + ReplicateStatusCause::PostComputeFailedUnknownIssue, + "POST_COMPUTE_FAILED_UNKNOWN_ISSUE", + "Unexpected error occurred" + )] + #[case( + ReplicateStatusCause::PostComputeInvalidTeeSignature, + "POST_COMPUTE_INVALID_TEE_SIGNATURE", + "Invalid TEE signature" + )] + #[case( + ReplicateStatusCause::PostComputeIpfsUploadFailed, + "POST_COMPUTE_IPFS_UPLOAD_FAILED", + "Failed to upload to IPFS" + )] + #[case( + ReplicateStatusCause::PostComputeMalformedEncryptionPublicKey, + "POST_COMPUTE_MALFORMED_ENCRYPTION_PUBLIC_KEY", + "Encryption public key is malformed" + )] + #[case( + ReplicateStatusCause::PostComputeOutFolderZipFailed, + "POST_COMPUTE_OUT_FOLDER_ZIP_FAILED", + "Failed to zip result folder" + )] + #[case( + ReplicateStatusCause::PostComputeResultDigestComputationFailed, + "POST_COMPUTE_RESULT_DIGEST_COMPUTATION_FAILED", + "Empty resultDigest" + )] + #[case( + ReplicateStatusCause::PostComputeResultFileNotFound, + "POST_COMPUTE_RESULT_FILE_NOT_FOUND", + "Result file not found" + )] + #[case( + ReplicateStatusCause::PostComputeSendComputedFileFailed, + "POST_COMPUTE_SEND_COMPUTED_FILE_FAILED", + "Failed to send computed file" + )] + #[case( + ReplicateStatusCause::PostComputeStorageTokenMissing, + "POST_COMPUTE_STORAGE_TOKEN_MISSING", + "Storage token related environment variable is missing" + )] + #[case( + ReplicateStatusCause::PostComputeTaskIdMissing, + "POST_COMPUTE_TASK_ID_MISSING", + "Task ID related environment variable is missing" + )] + #[case( + ReplicateStatusCause::PostComputeTeeChallengePrivateKeyMissing, + "POST_COMPUTE_TEE_CHALLENGE_PRIVATE_KEY_MISSING", + "Tee challenge private key related environment variable is missing" + )] + #[case( + ReplicateStatusCause::PostComputeTooLongResultFileName, + "POST_COMPUTE_TOO_LONG_RESULT_FILE_NAME", + "Result file name too long" + )] + #[case( + ReplicateStatusCause::PostComputeWorkerAddressMissing, + "POST_COMPUTE_WORKER_ADDRESS_MISSING", + "Worker address related environment variable is missing" + )] + fn error_variant_serializes_with_correct_cause_and_message( + #[case] error: ReplicateStatusCause, + #[case] expected_cause: &str, + #[case] expected_message: &str, + ) { + let serialized = to_value(&error).unwrap(); + assert_eq!( + serialized, + json!({ + "cause": expected_cause, + "message": expected_message + }) + ); + } + + #[test] + fn error_list_serializes_as_json_array() { + let errors = vec![ + ReplicateStatusCause::PostComputeComputedFileNotFound, + ReplicateStatusCause::PostComputeInvalidTeeSignature, + ReplicateStatusCause::PostComputeTaskIdMissing, + ]; + let serialized = to_value(&errors).unwrap(); + let expected = json!([ + { + "cause": "POST_COMPUTE_COMPUTED_FILE_NOT_FOUND", + "message": "computed.json file missing" + }, + { + "cause": "POST_COMPUTE_INVALID_TEE_SIGNATURE", + "message": "Invalid TEE signature" + }, + { + "cause": "POST_COMPUTE_TASK_ID_MISSING", + "message": "Task ID related environment variable is missing" + } + ]); + assert_eq!(serialized, expected); + } + + #[test] + fn empty_error_list_serializes_as_empty_json_array() { + let errors: Vec = vec![]; + let serialized = to_value(&errors).unwrap(); + assert_eq!(serialized, json!([])); + } +} diff --git a/post-compute/src/compute/utils/hash_utils.rs b/post-compute/src/compute/utils/hash_utils.rs index 6916772..54d8f01 100644 --- a/post-compute/src/compute/utils/hash_utils.rs +++ b/post-compute/src/compute/utils/hash_utils.rs @@ -17,7 +17,7 @@ pub fn hex_string_to_byte_array(input: &str) -> Vec { } let mut data: Vec = vec![]; - let start_idx = if len % 2 != 0 { + let start_idx = if !len.is_multiple_of(2) { let byte = u8::from_str_radix(&clean_input[0..1], 16).expect("Invalid hex digit in input string"); data.push(byte); From 65769cf0ea08eba3a7118d18ec300fc7621092bf Mon Sep 17 00:00:00 2001 From: Natchica Date: Thu, 16 Oct 2025 15:48:05 +0200 Subject: [PATCH 2/6] fix: simplify debug string formatting in ReplicateStatusCause --- post-compute/src/compute/errors.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/post-compute/src/compute/errors.rs b/post-compute/src/compute/errors.rs index b3b8f7b..9933703 100644 --- a/post-compute/src/compute/errors.rs +++ b/post-compute/src/compute/errors.rs @@ -42,7 +42,7 @@ pub enum ReplicateStatusCause { impl ReplicateStatusCause { fn to_screaming_snake_case(&self) -> String { - let debug_str = format!("{:?}", self); + let debug_str = format!("{self:?}"); let mut result = String::new(); let mut prev_was_lowercase = false; From 8f1d9eb35819e2ece4d8882a1f93674907d822f6 Mon Sep 17 00:00:00 2001 From: Natchica Date: Tue, 21 Oct 2025 15:06:35 +0200 Subject: [PATCH 3/6] refactor: remove unused dependencies and integrate strum for error serialization in ReplicateStatusCause --- Cargo.lock | 62 ++++------- post-compute/Cargo.toml | 3 +- post-compute/src/compute/errors.rs | 158 ++++------------------------- 3 files changed, 41 insertions(+), 182 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9d6f5b3..d22bf57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1700,12 +1700,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.31" @@ -3337,12 +3331,6 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" -[[package]] -name = "relative-path" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" - [[package]] name = "reqwest" version = "0.12.23" @@ -3439,35 +3427,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "rstest" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49" -dependencies = [ - "futures-timer", - "futures-util", - "rstest_macros", -] - -[[package]] -name = "rstest_macros" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0" -dependencies = [ - "cfg-if", - "glob", - "proc-macro-crate", - "proc-macro2", - "quote", - "regex", - "relative-path", - "rustc_version 0.4.1", - "syn 2.0.106", - "unicode-ident", -] - [[package]] name = "ruint" version = "1.16.0" @@ -4080,6 +4039,24 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "subtle" version = "2.6.1" @@ -4184,12 +4161,13 @@ dependencies = [ "rand 0.8.5", "reqwest", "rsa", - "rstest", "serde", "serde_json", "serial_test", "sha256", "sha3", + "strum", + "strum_macros", "temp-env", "tempfile", "thiserror", diff --git a/post-compute/Cargo.toml b/post-compute/Cargo.toml index f98031e..4d1bf7e 100644 --- a/post-compute/Cargo.toml +++ b/post-compute/Cargo.toml @@ -22,6 +22,8 @@ serde = "1.0.219" serde_json = "1.0.140" sha256 = "1.6.0" sha3 = "0.10.8" +strum = "0.27.2" +strum_macros = "0.27.2" thiserror = "2.0.12" walkdir = "2.5.0" zip = "4.0.0" @@ -30,7 +32,6 @@ zip = "4.0.0" logtest = "2.0.0" mockall = "0.13.1" once_cell = "1.21.3" -rstest = "0.26.1" serial_test = "3.2.0" temp-env = "0.3.6" tempfile = "3.20.0" diff --git a/post-compute/src/compute/errors.rs b/post-compute/src/compute/errors.rs index 9933703..59212b3 100644 --- a/post-compute/src/compute/errors.rs +++ b/post-compute/src/compute/errors.rs @@ -1,8 +1,10 @@ -use serde::{Deserialize, Serialize}; +use serde::{Serializer, ser::SerializeStruct}; +use strum_macros::EnumDiscriminants; use thiserror::Error; -#[derive(Debug, PartialEq, Clone, Error, Deserialize)] -#[allow(clippy::enum_variant_names)] +#[derive(Debug, PartialEq, Clone, Error, EnumDiscriminants)] +#[strum_discriminants(derive(serde::Serialize))] +#[strum_discriminants(serde(rename_all = "SCREAMING_SNAKE_CASE"))] pub enum ReplicateStatusCause { #[error("computed.json file missing")] PostComputeComputedFileNotFound, @@ -40,153 +42,31 @@ pub enum ReplicateStatusCause { PostComputeWorkerAddressMissing, } -impl ReplicateStatusCause { - fn to_screaming_snake_case(&self) -> String { - let debug_str = format!("{self:?}"); - let mut result = String::new(); - let mut prev_was_lowercase = false; - - for c in debug_str.chars() { - if c.is_uppercase() && !result.is_empty() && prev_was_lowercase { - result.push('_'); - } - result.push(c.to_ascii_uppercase()); - prev_was_lowercase = c.is_lowercase(); - } - - result - } -} - -#[derive(Debug, Serialize)] -pub struct WorkflowError { - pub cause: String, - pub message: String, -} - -impl From<&ReplicateStatusCause> for WorkflowError { - fn from(cause: &ReplicateStatusCause) -> Self { - WorkflowError { - cause: cause.to_screaming_snake_case(), - message: cause.to_string(), - } - } -} - -impl Serialize for ReplicateStatusCause { +impl serde::Serialize for ReplicateStatusCause { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer, + S: Serializer, { - WorkflowError::from(self).serialize(serializer) + let mut state = serializer.serialize_struct("ReplicateStatusCause", 2)?; + state.serialize_field("cause", &ReplicateStatusCauseDiscriminants::from(self))?; + state.serialize_field("message", &self.to_string())?; + state.end() } } #[cfg(test)] mod tests { use super::*; - use rstest::rstest; use serde_json::{json, to_value}; - #[rstest] - #[case( - ReplicateStatusCause::PostComputeComputedFileNotFound, - "POST_COMPUTE_COMPUTED_FILE_NOT_FOUND", - "computed.json file missing" - )] - #[case( - ReplicateStatusCause::PostComputeDropboxUploadFailed, - "POST_COMPUTE_DROPBOX_UPLOAD_FAILED", - "Failed to upload to Dropbox" - )] - #[case( - ReplicateStatusCause::PostComputeEncryptionFailed, - "POST_COMPUTE_ENCRYPTION_FAILED", - "Encryption stage failed" - )] - #[case( - ReplicateStatusCause::PostComputeEncryptionPublicKeyMissing, - "POST_COMPUTE_ENCRYPTION_PUBLIC_KEY_MISSING", - "Encryption public key related environment variable is missing" - )] - #[case( - ReplicateStatusCause::PostComputeFailedUnknownIssue, - "POST_COMPUTE_FAILED_UNKNOWN_ISSUE", - "Unexpected error occurred" - )] - #[case( - ReplicateStatusCause::PostComputeInvalidTeeSignature, - "POST_COMPUTE_INVALID_TEE_SIGNATURE", - "Invalid TEE signature" - )] - #[case( - ReplicateStatusCause::PostComputeIpfsUploadFailed, - "POST_COMPUTE_IPFS_UPLOAD_FAILED", - "Failed to upload to IPFS" - )] - #[case( - ReplicateStatusCause::PostComputeMalformedEncryptionPublicKey, - "POST_COMPUTE_MALFORMED_ENCRYPTION_PUBLIC_KEY", - "Encryption public key is malformed" - )] - #[case( - ReplicateStatusCause::PostComputeOutFolderZipFailed, - "POST_COMPUTE_OUT_FOLDER_ZIP_FAILED", - "Failed to zip result folder" - )] - #[case( - ReplicateStatusCause::PostComputeResultDigestComputationFailed, - "POST_COMPUTE_RESULT_DIGEST_COMPUTATION_FAILED", - "Empty resultDigest" - )] - #[case( - ReplicateStatusCause::PostComputeResultFileNotFound, - "POST_COMPUTE_RESULT_FILE_NOT_FOUND", - "Result file not found" - )] - #[case( - ReplicateStatusCause::PostComputeSendComputedFileFailed, - "POST_COMPUTE_SEND_COMPUTED_FILE_FAILED", - "Failed to send computed file" - )] - #[case( - ReplicateStatusCause::PostComputeStorageTokenMissing, - "POST_COMPUTE_STORAGE_TOKEN_MISSING", - "Storage token related environment variable is missing" - )] - #[case( - ReplicateStatusCause::PostComputeTaskIdMissing, - "POST_COMPUTE_TASK_ID_MISSING", - "Task ID related environment variable is missing" - )] - #[case( - ReplicateStatusCause::PostComputeTeeChallengePrivateKeyMissing, - "POST_COMPUTE_TEE_CHALLENGE_PRIVATE_KEY_MISSING", - "Tee challenge private key related environment variable is missing" - )] - #[case( - ReplicateStatusCause::PostComputeTooLongResultFileName, - "POST_COMPUTE_TOO_LONG_RESULT_FILE_NAME", - "Result file name too long" - )] - #[case( - ReplicateStatusCause::PostComputeWorkerAddressMissing, - "POST_COMPUTE_WORKER_ADDRESS_MISSING", - "Worker address related environment variable is missing" - )] - fn error_variant_serializes_with_correct_cause_and_message( - #[case] error: ReplicateStatusCause, - #[case] expected_cause: &str, - #[case] expected_message: &str, - ) { - let serialized = to_value(&error).unwrap(); - assert_eq!( - serialized, - json!({ - "cause": expected_cause, - "message": expected_message - }) - ); + #[test] + fn error_variant_serialize_correctly() { + let expected = json!({ + "cause": "POST_COMPUTE_TooLongResultFileName", + "message": "Result file name too long" + }); + let error_variant = ReplicateStatusCause::PostComputeTooLongResultFileName; + assert_eq!(to_value(&error_variant).unwrap(), expected); } #[test] From 9f34cb6b3405e2982b9c8e1b05f26af43afe507a Mon Sep 17 00:00:00 2001 From: Natchica Date: Tue, 21 Oct 2025 15:32:09 +0200 Subject: [PATCH 4/6] refactor: update exit cause handling in Worker API --- post-compute/src/api/worker_api.rs | 56 +++++++++++++------------- post-compute/src/compute/app_runner.rs | 18 ++++----- post-compute/src/compute/errors.rs | 2 +- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/post-compute/src/api/worker_api.rs b/post-compute/src/api/worker_api.rs index 488fa49..746ef2a 100644 --- a/post-compute/src/api/worker_api.rs +++ b/post-compute/src/api/worker_api.rs @@ -60,21 +60,21 @@ impl WorkerApiClient { Self::new(&base_url) } - /// Sends an exit cause for a post-compute operation to the Worker API. + /// Sends exit causes for a post-compute operation to the Worker API. /// - /// This method reports the exit cause of a post-compute operation to the Worker API, + /// This method reports the exit causes of a post-compute operation to the Worker API, /// which can be used for tracking and debugging purposes. /// /// # Arguments /// /// * `authorization` - The authorization token to use for the API request - /// * `chain_task_id` - The chain task ID for which to report the exit cause - /// * `exit_cause` - The exit cause to report + /// * `chain_task_id` - The chain task ID for which to report the exit causes + /// * `exit_causes` - The exit causes to report /// /// # Returns /// - /// * `Ok(())` - If the exit cause was successfully reported - /// * `Err(ReplicateStatusCause)` - If the exit cause could not be reported due to an HTTP error + /// * `Ok(())` - If the exit causes were successfully reported + /// * `Err(ReplicateStatusCause)` - If the exit causes could not be reported due to an HTTP error /// /// # Errors /// @@ -92,22 +92,22 @@ impl WorkerApiClient { /// let client = WorkerApiClient::new("http://worker:13100"); /// let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature]; /// - /// match client.send_exit_cause_for_post_compute_stage( + /// match client.send_exit_causes_for_post_compute_stage( /// "authorization_token", /// "0x123456789abcdef", /// &exit_causes, /// ) { - /// Ok(()) => println!("Exit cause reported successfully"), - /// Err(error) => eprintln!("Failed to report exit cause: {}", error), + /// Ok(()) => println!("Exit causes reported successfully"), + /// Err(error) => eprintln!("Failed to report exit causes: {}", error), /// } /// ``` - pub fn send_exit_cause_for_post_compute_stage( + pub fn send_exit_causes_for_post_compute_stage( &self, authorization: &str, chain_task_id: &str, exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { - let url = format!("{}/compute/post/{chain_task_id}/exit", self.base_url); + let url = format!("{}/compute/post/{chain_task_id}/exit-causes", self.base_url); match self .client .post(&url) @@ -122,13 +122,13 @@ impl WorkerApiClient { let status = response.status(); let body = response.text().unwrap_or_default(); error!( - "Failed to send exit cause to worker: [status:{status:?}, body:{body:#?}]" + "Failed to send exit causes to worker: [status:{status:?}, body:{body:#?}]" ); Err(ReplicateStatusCause::PostComputeFailedUnknownIssue) } } Err(e) => { - error!("An error occured while sending exit cause to worker: {e}"); + error!("An error occured while sending exit causes to worker: {e}"); Err(ReplicateStatusCause::PostComputeFailedUnknownIssue) } } @@ -232,7 +232,7 @@ mod tests { // region serialize List of ReplicateStatusCause #[test] - fn should_serialize_list_of_exit_causes() { + fn replicate_status_cause_serializes_as_json_array_when_multiple_causes() { let causes = vec![ ReplicateStatusCause::PostComputeInvalidTeeSignature, ReplicateStatusCause::PostComputeWorkerAddressMissing, @@ -243,7 +243,7 @@ mod tests { } #[test] - fn should_serialize_single_exit_cause() { + fn replicate_status_cause_serializes_as_json_array_when_single_cause() { let causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue]; let serialized = to_string(&causes).expect("Failed to serialize"); let expected = r#"[{"cause":"POST_COMPUTE_FAILED_UNKNOWN_ISSUE","message":"Unexpected error occurred"}]"#; @@ -253,7 +253,7 @@ mod tests { // region get_worker_api_client #[test] - fn should_get_worker_api_client_with_env_var() { + fn from_env_creates_client_with_custom_url_when_env_var_set() { with_vars( vec![(WorkerHostEnvVar.name(), Some("custom-worker-host:9999"))], || { @@ -264,7 +264,7 @@ mod tests { } #[test] - fn should_get_worker_api_client_without_env_var() { + fn from_env_creates_client_with_default_url_when_env_var_missing() { with_vars(vec![(WorkerHostEnvVar.name(), None::<&str>)], || { let client = WorkerApiClient::from_env(); assert_eq!(client.base_url, format!("http://{DEFAULT_WORKER_HOST}")); @@ -272,16 +272,16 @@ mod tests { } // endregion - // region send_exit_cause_for_post_compute_stage() + // region send_exit_causes_for_post_compute_stage() #[tokio::test] - async fn should_send_exit_cause() { + async fn send_exit_causes_for_post_compute_stage_succeeds_when_server_responds_ok() { let mock_server = MockServer::start().await; let server_url = mock_server.uri(); let expected_body = json!([ReplicateStatusCause::PostComputeInvalidTeeSignature,]); Mock::given(method("POST")) - .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit-causes"))) .and(header("Authorization", CHALLENGE)) .and(body_json(&expected_body)) .respond_with(ResponseTemplate::new(200)) @@ -292,7 +292,7 @@ mod tests { let result = tokio::task::spawn_blocking(move || { let exit_causes = vec![ReplicateStatusCause::PostComputeInvalidTeeSignature]; let worker_api_client = WorkerApiClient::new(&server_url); - worker_api_client.send_exit_cause_for_post_compute_stage( + worker_api_client.send_exit_causes_for_post_compute_stage( CHALLENGE, CHAIN_TASK_ID, &exit_causes, @@ -306,7 +306,7 @@ mod tests { #[tokio::test] #[serial] - async fn should_not_send_exit_cause() { + async fn send_exit_causes_for_post_compute_stage_fails_when_server_returns_404() { { let mut logger = TEST_LOGGER.lock().unwrap(); while logger.pop().is_some() {} @@ -315,7 +315,7 @@ mod tests { let server_url = mock_server.uri(); Mock::given(method("POST")) - .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/post/{CHAIN_TASK_ID}/exit-causes"))) .respond_with(ResponseTemplate::new(404)) .expect(1) .mount(&mock_server) @@ -324,7 +324,7 @@ mod tests { let result = tokio::task::spawn_blocking(move || { let exit_causes = vec![ReplicateStatusCause::PostComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new(&server_url); - worker_api_client.send_exit_cause_for_post_compute_stage( + worker_api_client.send_exit_causes_for_post_compute_stage( CHALLENGE, CHAIN_TASK_ID, &exit_causes, @@ -356,7 +356,7 @@ mod tests { // region send_computed_file_to_host() #[tokio::test] - async fn should_send_computed_file_successfully() { + async fn send_computed_file_to_host_succeeds_when_server_responds_ok() { let mock_server = MockServer::start().await; let server_uri = mock_server.uri(); @@ -391,7 +391,7 @@ mod tests { #[tokio::test] #[serial] - async fn should_fail_send_computed_file_on_server_error() { + async fn send_computed_file_to_host_fails_when_server_returns_500() { { let mut logger = TEST_LOGGER.lock().unwrap(); while logger.pop().is_some() {} @@ -445,7 +445,7 @@ mod tests { #[tokio::test] #[serial] - async fn should_handle_invalid_chain_task_id_in_url() { + async fn send_computed_file_to_host_fails_when_chain_task_id_invalid() { { let mut logger = TEST_LOGGER.lock().unwrap(); while logger.pop().is_some() {} @@ -486,7 +486,7 @@ mod tests { } #[tokio::test] - async fn should_send_computed_file_with_minimal_data() { + async fn send_computed_file_to_host_succeeds_when_minimal_data_provided() { let mock_server = MockServer::start().await; let server_uri = mock_server.uri(); diff --git a/post-compute/src/compute/app_runner.rs b/post-compute/src/compute/app_runner.rs index 707efc0..7199a0e 100644 --- a/post-compute/src/compute/app_runner.rs +++ b/post-compute/src/compute/app_runner.rs @@ -31,7 +31,7 @@ pub enum ExitMode { pub trait PostComputeRunnerInterface { fn run_post_compute(&self, chain_task_id: &str) -> Result<(), ReplicateStatusCause>; fn get_challenge(&self, chain_task_id: &str) -> Result; - fn send_exit_cause( + fn send_exit_causes( &self, authorization: &str, chain_task_id: &str, @@ -99,14 +99,14 @@ impl PostComputeRunnerInterface for DefaultPostComputeRunner { get_challenge(chain_task_id) } - fn send_exit_cause( + fn send_exit_causes( &self, authorization: &str, chain_task_id: &str, exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { self.worker_api_client - .send_exit_cause_for_post_compute_stage(authorization, chain_task_id, exit_causes) + .send_exit_causes_for_post_compute_stage(authorization, chain_task_id, exit_causes) } fn send_computed_file(&self, computed_file: &ComputedFile) -> Result<(), ReplicateStatusCause> { @@ -191,10 +191,10 @@ pub fn start_with_runner(runner: &R) -> ExitMode let exit_causes = vec![exit_cause.clone()]; - match runner.send_exit_cause(&authorization, &chain_task_id, &exit_causes) { + match runner.send_exit_causes(&authorization, &chain_task_id, &exit_causes) { Ok(()) => ExitMode::ReportedFailure, Err(_) => { - error!("Failed to report exit cause [exitCause:{exit_cause}]"); + error!("Failed to report exit causes [exitCauses:{exit_causes:?}]"); ExitMode::UnreportedFailure } } @@ -273,7 +273,7 @@ mod tests { self } - fn with_send_exit_cause_failure(mut self) -> Self { + fn with_send_exit_causes_failure(mut self) -> Self { self.send_exit_cause_success = false; self } @@ -298,7 +298,7 @@ mod tests { } } - fn send_exit_cause( + fn send_exit_causes( &self, _authorization: &str, _chain_task_id: &str, @@ -423,7 +423,7 @@ mod tests { } #[test] - fn start_return_2_when_exit_cause_not_transmitted() { + fn start_return_2_when_exit_causes_not_transmitted() { with_vars( vec![( TeeSessionEnvironmentVariable::IexecTaskId.name(), @@ -434,7 +434,7 @@ mod tests { .with_run_post_compute_failure(Some( ReplicateStatusCause::PostComputeInvalidTeeSignature, )) - .with_send_exit_cause_failure(); + .with_send_exit_causes_failure(); let result = start_with_runner(&runner); assert_eq!( diff --git a/post-compute/src/compute/errors.rs b/post-compute/src/compute/errors.rs index 59212b3..6e52b59 100644 --- a/post-compute/src/compute/errors.rs +++ b/post-compute/src/compute/errors.rs @@ -62,7 +62,7 @@ mod tests { #[test] fn error_variant_serialize_correctly() { let expected = json!({ - "cause": "POST_COMPUTE_TooLongResultFileName", + "cause": "POST_COMPUTE_TOO_LONG_RESULT_FILE_NAME", "message": "Result file name too long" }); let error_variant = ReplicateStatusCause::PostComputeTooLongResultFileName; From 4111d7fed84fe0a1191798a64c2a08efb6f45661 Mon Sep 17 00:00:00 2001 From: Natchica Date: Tue, 21 Oct 2025 15:46:27 +0200 Subject: [PATCH 5/6] refactor: update error messages in ReplicateStatusCause for clarity --- post-compute/src/api/worker_api.rs | 2 +- post-compute/src/compute/errors.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/post-compute/src/api/worker_api.rs b/post-compute/src/api/worker_api.rs index 746ef2a..265bd03 100644 --- a/post-compute/src/api/worker_api.rs +++ b/post-compute/src/api/worker_api.rs @@ -238,7 +238,7 @@ mod tests { ReplicateStatusCause::PostComputeWorkerAddressMissing, ]; let serialized = to_string(&causes).expect("Failed to serialize"); - let expected = r#"[{"cause":"POST_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"},{"cause":"POST_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address related environment variable is missing"}]"#; + let expected = r#"[{"cause":"POST_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"},{"cause":"POST_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address not found in TEE session"}]"#; assert_eq!(serialized, expected); } diff --git a/post-compute/src/compute/errors.rs b/post-compute/src/compute/errors.rs index 6e52b59..8c6c228 100644 --- a/post-compute/src/compute/errors.rs +++ b/post-compute/src/compute/errors.rs @@ -12,7 +12,7 @@ pub enum ReplicateStatusCause { PostComputeDropboxUploadFailed, #[error("Encryption stage failed")] PostComputeEncryptionFailed, - #[error("Encryption public key related environment variable is missing")] + #[error("Encryption public key not found in TEE session")] PostComputeEncryptionPublicKeyMissing, #[error("Unexpected error occurred")] PostComputeFailedUnknownIssue, @@ -30,15 +30,15 @@ pub enum ReplicateStatusCause { PostComputeResultFileNotFound, #[error("Failed to send computed file")] PostComputeSendComputedFileFailed, - #[error("Storage token related environment variable is missing")] + #[error("Storage token not found in TEE session")] PostComputeStorageTokenMissing, - #[error("Task ID related environment variable is missing")] + #[error("Task ID not found in TEE session")] PostComputeTaskIdMissing, - #[error("Tee challenge private key related environment variable is missing")] + #[error("TEE challenge private key not found in TEE session")] PostComputeTeeChallengePrivateKeyMissing, #[error("Result file name too long")] PostComputeTooLongResultFileName, - #[error("Worker address related environment variable is missing")] + #[error("Worker address not found in TEE session")] PostComputeWorkerAddressMissing, } @@ -88,7 +88,7 @@ mod tests { }, { "cause": "POST_COMPUTE_TASK_ID_MISSING", - "message": "Task ID related environment variable is missing" + "message": "Task ID not found in TEE session" } ]); assert_eq!(serialized, expected); From b2d67481dd87dbe05b6aeea8aa73f9b6cfc503df Mon Sep 17 00:00:00 2001 From: nabil-Tounarti <117689544+nabil-Tounarti@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:00:30 +0200 Subject: [PATCH 6/6] feat: update ReplicateStatusCause serialization and worker_api to support new WorkflowError format (#21) --- Cargo.lock | 2 + pre-compute/Cargo.toml | 2 + pre-compute/src/api/worker_api.rs | 159 +++++++++----------- pre-compute/src/compute/app_runner.rs | 22 ++- pre-compute/src/compute/dataset.rs | 40 +++-- pre-compute/src/compute/errors.rs | 131 +++++++++++++--- pre-compute/src/compute/pre_compute_app.rs | 2 +- pre-compute/src/compute/pre_compute_args.rs | 38 ++--- pre-compute/src/compute/utils/env_utils.rs | 5 +- 9 files changed, 250 insertions(+), 151 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d22bf57..56fe159 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4195,6 +4195,8 @@ dependencies = [ "serde_json", "sha256", "sha3", + "strum", + "strum_macros", "temp-env", "tempfile", "testcontainers", diff --git a/pre-compute/Cargo.toml b/pre-compute/Cargo.toml index 8ecd4a2..c821816 100644 --- a/pre-compute/Cargo.toml +++ b/pre-compute/Cargo.toml @@ -16,6 +16,8 @@ reqwest = { version = "0.12.15", features = ["blocking", "json"] } serde = "1.0.219" sha256 = "1.6.0" sha3 = "0.10.8" +strum = "0.27.2" +strum_macros = "0.27.2" thiserror = "2.0.12" [dev-dependencies] diff --git a/pre-compute/src/api/worker_api.rs b/pre-compute/src/api/worker_api.rs index 3aeeec4..838e332 100644 --- a/pre-compute/src/api/worker_api.rs +++ b/pre-compute/src/api/worker_api.rs @@ -4,40 +4,6 @@ use crate::compute::{ }; use log::error; use reqwest::{blocking::Client, header::AUTHORIZATION}; -use serde::Serialize; - -/// Represents payload that can be sent to the worker API to report the outcome of the -/// pre‑compute stage. -/// -/// The JSON structure expected by the REST endpoint is: -/// ```json -/// { -/// "cause": "" -/// } -/// ``` -/// -/// # Arguments -/// -/// * `cause` - A reference to the ReplicateStatusCause indicating why the pre-compute operation exited -/// -/// # Example -/// -/// ```rust -/// use tee_worker_pre_compute::api::worker_api::ExitMessage; -/// use tee_worker_pre_compute::compute::errors::ReplicateStatusCause; -/// -/// let exit_message = ExitMessage::from(&ReplicateStatusCause::PreComputeInvalidTeeSignature); -/// ``` -#[derive(Serialize, Debug)] -pub struct ExitMessage<'a> { - pub cause: &'a ReplicateStatusCause, -} - -impl<'a> From<&'a ReplicateStatusCause> for ExitMessage<'a> { - fn from(cause: &'a ReplicateStatusCause) -> Self { - Self { cause } - } -} /// Thin wrapper around a [`Client`] that knows how to reach the iExec worker API. /// @@ -93,21 +59,21 @@ impl WorkerApiClient { Self::new(&base_url) } - /// Sends an exit cause for a pre-compute operation to the Worker API. + /// Sends exit causes for a pre-compute operation to the Worker API. /// - /// This method reports the exit cause of a pre-compute operation to the Worker API, + /// This method reports the exit causes of a pre-compute operation to the Worker API, /// which can be used for tracking and debugging purposes. /// /// # Arguments /// /// * `authorization` - The authorization token to use for the API request - /// * `chain_task_id` - The chain task ID for which to report the exit cause - /// * `exit_cause` - The exit cause to report + /// * `chain_task_id` - The chain task ID for which to report the exit causes + /// * `exit_causes` - The list of exit causes to report /// /// # Returns /// - /// * `Ok(())` - If the exit cause was successfully reported - /// * `Err(Error)` - If the exit cause could not be reported due to an HTTP error + /// * `Ok(())` - If the exit causes were successfully reported + /// * `Err(Error)` - If the exit causes could not be reported due to an HTTP error /// /// # Errors /// @@ -117,33 +83,33 @@ impl WorkerApiClient { /// # Example /// /// ```rust - /// use tee_worker_pre_compute::api::worker_api::{ExitMessage, WorkerApiClient}; + /// use tee_worker_pre_compute::api::worker_api::WorkerApiClient; /// use tee_worker_pre_compute::compute::errors::ReplicateStatusCause; /// /// let client = WorkerApiClient::new("http://worker:13100"); - /// let exit_message = ExitMessage::from(&ReplicateStatusCause::PreComputeInvalidTeeSignature); + /// let exit_causes = vec![ReplicateStatusCause::PreComputeInvalidTeeSignature]; /// - /// match client.send_exit_cause_for_pre_compute_stage( + /// match client.send_exit_causes_for_pre_compute_stage( /// "authorization_token", /// "0x123456789abcdef", - /// &exit_message, + /// &exit_causes, /// ) { - /// Ok(()) => println!("Exit cause reported successfully"), - /// Err(error) => eprintln!("Failed to report exit cause: {error}"), + /// Ok(()) => println!("Exit causes reported successfully"), + /// Err(error) => eprintln!("Failed to report exit causes: {error}"), /// } /// ``` - pub fn send_exit_cause_for_pre_compute_stage( + pub fn send_exit_causes_for_pre_compute_stage( &self, authorization: &str, chain_task_id: &str, - exit_cause: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { - let url = format!("{}/compute/pre/{chain_task_id}/exit", self.base_url); + let url = format!("{}/compute/pre/{chain_task_id}/exit-causes", self.base_url); match self .client .post(&url) .header(AUTHORIZATION, authorization) - .json(exit_cause) + .json(exit_causes) .send() { Ok(resp) => { @@ -152,12 +118,12 @@ impl WorkerApiClient { Ok(()) } else { let body = resp.text().unwrap_or_default(); - error!("Failed to send exit cause: [status:{status}, body:{body}]"); + error!("Failed to send exit causes: [status:{status}, body:{body}]"); Err(ReplicateStatusCause::PreComputeFailedUnknownIssue) } } Err(err) => { - error!("HTTP request failed when sending exit cause to {url}: {err:?}"); + error!("HTTP request failed when sending exit causes to {url}: {err:?}"); Err(ReplicateStatusCause::PreComputeFailedUnknownIssue) } } @@ -175,36 +141,52 @@ mod tests { matchers::{body_json, header, method, path}, }; - // region ExitMessage() + // region Serialization tests #[test] - fn should_serialize_exit_message() { - let causes = [ + fn serialize_replicate_status_cause_succeeds_when_single_cause() { + let causes = vec![ ( ReplicateStatusCause::PreComputeInvalidTeeSignature, - "PRE_COMPUTE_INVALID_TEE_SIGNATURE", + r#"{"cause":"PRE_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"}"#, ), ( ReplicateStatusCause::PreComputeWorkerAddressMissing, - "PRE_COMPUTE_WORKER_ADDRESS_MISSING", + r#"{"cause":"PRE_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address related environment variable is missing"}"#, ), ( - ReplicateStatusCause::PreComputeFailedUnknownIssue, - "PRE_COMPUTE_FAILED_UNKNOWN_ISSUE", + ReplicateStatusCause::PreComputeDatasetUrlMissing("0xDatasetAdress1".to_string()), + r#"{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAdress1"}"#, + ), + ( + ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + "0xDatasetAdress2".to_string(), + ), + r#"{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xDatasetAdress2"}"#, ), ]; - for (cause, message) in causes { - let exit_message = ExitMessage::from(&cause); - let serialized = to_string(&exit_message).expect("Failed to serialize"); - let expected = format!("{{\"cause\":\"{message}\"}}"); - assert_eq!(serialized, expected); + for (cause, expected_json) in causes { + let serialized = to_string(&cause).expect("Failed to serialize"); + assert_eq!(serialized, expected_json); } } + + #[test] + fn serialize_vec_of_causes_succeeds_when_multiple_causes() { + let causes = vec![ + ReplicateStatusCause::PreComputeDatasetUrlMissing("0xDatasetAdress".to_string()), + ReplicateStatusCause::PreComputeInvalidDatasetChecksum("0xDatasetAdress".to_string()), + ]; + + let serialized = to_string(&causes).expect("Failed to serialize"); + let expected = r#"[{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAdress"},{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xDatasetAdress"}]"#; + assert_eq!(serialized, expected); + } // endregion // region get_worker_api_client #[test] - fn should_get_worker_api_client_with_env_var() { + fn from_env_creates_client_with_custom_host_when_env_var_set() { with_vars( vec![(WorkerHostEnvVar.name(), Some("custom-worker-host:9999"))], || { @@ -215,7 +197,7 @@ mod tests { } #[test] - fn should_get_worker_api_client_without_env_var() { + fn from_env_creates_client_with_default_host_when_env_var_unset() { temp_env::with_vars_unset(vec![WorkerHostEnvVar.name()], || { let client = WorkerApiClient::from_env(); assert_eq!(client.base_url, format!("http://{DEFAULT_WORKER_HOST}")); @@ -223,21 +205,24 @@ mod tests { } // endregion - // region send_exit_cause_for_pre_compute_stage() + // region send_exit_causes_for_pre_compute_stage() const CHALLENGE: &str = "challenge"; const CHAIN_TASK_ID: &str = "0x123456789abcdef"; #[tokio::test] - async fn should_send_exit_cause() { + async fn send_exit_causes_succeeds_when_api_returns_success() { let mock_server = MockServer::start().await; let server_url = mock_server.uri(); - let expected_body = json!({ - "cause": ReplicateStatusCause::PreComputeInvalidTeeSignature, - }); + let expected_body = json!([ + { + "cause": "PRE_COMPUTE_INVALID_TEE_SIGNATURE", + "message": "Invalid TEE signature" + } + ]); Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .and(header("Authorization", CHALLENGE)) .and(body_json(&expected_body)) .respond_with(ResponseTemplate::new(200)) @@ -246,13 +231,12 @@ mod tests { .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PreComputeInvalidTeeSignature); + let exit_causes = vec![ReplicateStatusCause::PreComputeInvalidTeeSignature]; let worker_api_client = WorkerApiClient::new(&server_url); - worker_api_client.send_exit_cause_for_pre_compute_stage( + worker_api_client.send_exit_causes_for_pre_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ) }) .await @@ -262,26 +246,25 @@ mod tests { } #[tokio::test] - async fn should_not_send_exit_cause() { + async fn send_exit_causes_fails_when_api_returns_error() { testing_logger::setup(); let mock_server = MockServer::start().await; let server_url = mock_server.uri(); Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .respond_with(ResponseTemplate::new(503).set_body_string("Service Unavailable")) .expect(1) .mount(&mock_server) .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PreComputeFailedUnknownIssue); + let exit_causes = vec![ReplicateStatusCause::PreComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new(&server_url); - let response = worker_api_client.send_exit_cause_for_pre_compute_stage( + let response = worker_api_client.send_exit_causes_for_pre_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ); testing_logger::validate(|captured_logs| { let logs = captured_logs @@ -292,7 +275,7 @@ mod tests { assert_eq!(logs.len(), 1); assert_eq!( logs[0].body, - "Failed to send exit cause: [status:503 Service Unavailable, body:Service Unavailable]" + "Failed to send exit causes: [status:503 Service Unavailable, body:Service Unavailable]" ); }); response @@ -308,14 +291,14 @@ mod tests { } #[test] - fn test_send_exit_cause_http_request_failure() { + fn send_exit_causes_fails_when_http_request_invalid() { testing_logger::setup(); - let exit_message = ExitMessage::from(&ReplicateStatusCause::PreComputeFailedUnknownIssue); + let exit_causes = vec![ReplicateStatusCause::PreComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new("wrong_url"); - let result = worker_api_client.send_exit_cause_for_pre_compute_stage( + let result = worker_api_client.send_exit_causes_for_pre_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ); testing_logger::validate(|captured_logs| { let logs = captured_logs @@ -326,7 +309,7 @@ mod tests { assert_eq!(logs.len(), 1); assert_eq!( logs[0].body, - "HTTP request failed when sending exit cause to wrong_url/compute/pre/0x123456789abcdef/exit: reqwest::Error { kind: Builder, source: RelativeUrlWithoutBase }" + "HTTP request failed when sending exit causes to wrong_url/compute/pre/0x123456789abcdef/exit-causes: reqwest::Error { kind: Builder, source: RelativeUrlWithoutBase }" ); }); assert!(result.is_err()); diff --git a/pre-compute/src/compute/app_runner.rs b/pre-compute/src/compute/app_runner.rs index 40a1586..093a47b 100644 --- a/pre-compute/src/compute/app_runner.rs +++ b/pre-compute/src/compute/app_runner.rs @@ -1,4 +1,4 @@ -use crate::api::worker_api::{ExitMessage, WorkerApiClient}; +use crate::api::worker_api::WorkerApiClient; use crate::compute::pre_compute_app::{PreComputeApp, PreComputeAppTrait}; use crate::compute::{ errors::ReplicateStatusCause, @@ -61,14 +61,12 @@ pub fn start_with_app( } }; - let exit_message = ExitMessage { - cause: &exit_cause.clone(), - }; + let exit_causes = vec![exit_cause.clone()]; - match WorkerApiClient::from_env().send_exit_cause_for_pre_compute_stage( + match WorkerApiClient::from_env().send_exit_causes_for_pre_compute_stage( &authorization, chain_task_id, - &exit_message, + &exit_causes, ) { Ok(_) => ExitMode::ReportedFailure, Err(_) => { @@ -193,7 +191,7 @@ mod pre_compute_start_with_app_tests { let mock_server = MockServer::start().await; Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .respond_with(ResponseTemplate::new(500)) .mount(&mock_server) .await; @@ -231,14 +229,14 @@ mod pre_compute_start_with_app_tests { async fn start_succeeds_when_send_exit_cause_api_success() { let mock_server = MockServer::start().await; - let expected_cause_enum = ReplicateStatusCause::PreComputeOutputFolderNotFound; - let expected_exit_message_payload = json!({ - "cause": expected_cause_enum // Relies on ReplicateStatusCause's Serialize impl - }); + let expected_exit_message_payload = json!([{ + "cause": "PRE_COMPUTE_OUTPUT_FOLDER_NOT_FOUND", + "message": "Output folder related environment variable is missing" + }]); // Mock the worker API to return success Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .and(body_json(expected_exit_message_payload)) .respond_with(ResponseTemplate::new(200)) .expect(1) diff --git a/pre-compute/src/compute/dataset.rs b/pre-compute/src/compute/dataset.rs index 33003d0..f65543f 100644 --- a/pre-compute/src/compute/dataset.rs +++ b/pre-compute/src/compute/dataset.rs @@ -79,7 +79,9 @@ impl Dataset { } else { download_from_url(&self.url) } - .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed)?; + .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed( + self.filename.clone(), + ))?; info!("Checking encrypted dataset checksum [chainTaskId:{chain_task_id}]"); let actual_checksum = sha256_from_bytes(&encrypted_content); @@ -89,7 +91,9 @@ impl Dataset { "Invalid dataset checksum [chainTaskId:{chain_task_id}, expected:{}, actual:{actual_checksum}]", self.checksum ); - return Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum); + return Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + self.filename.clone(), + )); } info!("Dataset downloaded and verified successfully."); @@ -113,12 +117,14 @@ impl Dataset { &self, encrypted_content: &[u8], ) -> Result, ReplicateStatusCause> { - let key = general_purpose::STANDARD - .decode(&self.key) - .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed)?; + let key = general_purpose::STANDARD.decode(&self.key).map_err(|_| { + ReplicateStatusCause::PreComputeDatasetDecryptionFailed(self.filename.clone()) + })?; if encrypted_content.len() < AES_IV_LENGTH || key.len() != AES_KEY_LENGTH { - return Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed); + return Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed( + self.filename.clone(), + )); } let key_slice = &key[..AES_KEY_LENGTH]; @@ -127,7 +133,9 @@ impl Dataset { Aes256CbcDec::new(key_slice.into(), iv_slice.into()) .decrypt_padded_vec_mut::(ciphertext) - .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed) + .map_err(|_| { + ReplicateStatusCause::PreComputeDatasetDecryptionFailed(self.filename.clone()) + }) } } @@ -144,7 +152,7 @@ mod tests { "0x02a12ef127dcfbdb294a090c8f0b69a0ca30b7940fc36cabf971f488efd374d7"; const ENCRYPTED_DATASET_KEY: &str = "ubA6H9emVPJT91/flYAmnKHC0phSV3cfuqsLxQfgow0="; const HTTP_DATASET_URL: &str = "https://raw.githubusercontent.com/iExecBlockchainComputing/tee-worker-pre-compute-rust/main/src/tests_resources/encrypted-data.bin"; - const PLAIN_DATA_FILE: &str = "plain-data.txt"; + const PLAIN_DATA_FILE: &str = "0xDatasetAddress"; const IPFS_DATASET_URL: &str = "/ipfs/QmUVhChbLFiuzNK1g2GsWyWEiad7SXPqARnWzGumgziwEp"; fn get_test_dataset() -> Dataset { @@ -171,7 +179,9 @@ mod tests { let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID); assert_eq!( actual_content, - Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed) + Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed( + PLAIN_DATA_FILE.to_string() + )) ); } @@ -191,7 +201,9 @@ mod tests { let mut dataset = get_test_dataset(); dataset.url = "/ipfs/INVALID_IPFS_DATASET_URL".to_string(); let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID); - let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed); + let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed( + PLAIN_DATA_FILE.to_string(), + )); assert_eq!(actual_content, expected_content); } @@ -200,7 +212,9 @@ mod tests { let mut dataset = get_test_dataset(); dataset.checksum = "invalid_dataset_checksum".to_string(); let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID); - let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum); + let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + PLAIN_DATA_FILE.to_string(), + )); assert_eq!(actual_content, expected_content); } // endregion @@ -226,7 +240,9 @@ mod tests { assert_eq!( actual_plain_data, - Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed) + Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed( + PLAIN_DATA_FILE.to_string() + )) ); } // endregion diff --git a/pre-compute/src/compute/errors.rs b/pre-compute/src/compute/errors.rs index 51ceace..f9981de 100644 --- a/pre-compute/src/compute/errors.rs +++ b/pre-compute/src/compute/errors.rs @@ -1,24 +1,26 @@ -use serde::{Deserialize, Serialize}; +use serde::{Serializer, ser::SerializeStruct}; +use strum_macros::EnumDiscriminants; use thiserror::Error; -#[derive(Debug, PartialEq, Clone, Error, Serialize, Deserialize)] -#[serde(rename_all(serialize = "SCREAMING_SNAKE_CASE"))] +#[derive(Debug, PartialEq, Clone, Error, EnumDiscriminants)] +#[strum_discriminants(derive(serde::Serialize))] +#[strum_discriminants(serde(rename_all = "SCREAMING_SNAKE_CASE"))] #[allow(clippy::enum_variant_names)] pub enum ReplicateStatusCause { - #[error("At least one input file URL is missing")] - PreComputeAtLeastOneInputFileUrlMissing, - #[error("Dataset checksum related environment variable is missing")] - PreComputeDatasetChecksumMissing, - #[error("Failed to decrypt dataset")] - PreComputeDatasetDecryptionFailed, - #[error("Failed to download encrypted dataset file")] - PreComputeDatasetDownloadFailed, - #[error("Dataset filename related environment variable is missing")] - PreComputeDatasetFilenameMissing, - #[error("Dataset key related environment variable is missing")] - PreComputeDatasetKeyMissing, - #[error("Dataset URL related environment variable is missing")] - PreComputeDatasetUrlMissing, + #[error("input file URL {0} is missing")] + PreComputeAtLeastOneInputFileUrlMissing(usize), + #[error("Dataset checksum related environment variable is missing for dataset {0}")] + PreComputeDatasetChecksumMissing(String), + #[error("Failed to decrypt dataset {0}")] + PreComputeDatasetDecryptionFailed(String), + #[error("Failed to download encrypted dataset file for dataset {0}")] + PreComputeDatasetDownloadFailed(String), + #[error("Dataset filename related environment variable is missing for dataset {0}")] + PreComputeDatasetFilenameMissing(String), + #[error("Dataset key related environment variable is missing for dataset {0}")] + PreComputeDatasetKeyMissing(String), + #[error("Dataset URL related environment variable is missing for dataset {0}")] + PreComputeDatasetUrlMissing(String), #[error("Unexpected error occurred")] PreComputeFailedUnknownIssue, #[error("Invalid TEE signature")] @@ -29,9 +31,9 @@ pub enum ReplicateStatusCause { PreComputeInputFileDownloadFailed, #[error("Input files number related environment variable is missing")] PreComputeInputFilesNumberMissing, - #[error("Invalid dataset checksum")] - PreComputeInvalidDatasetChecksum, - #[error("Input files number related environment variable is missing")] + #[error("Invalid dataset checksum for dataset {0}")] + PreComputeInvalidDatasetChecksum(String), + #[error("Output folder related environment variable is missing")] PreComputeOutputFolderNotFound, #[error("Output path related environment variable is missing")] PreComputeOutputPathMissing, @@ -44,3 +46,92 @@ pub enum ReplicateStatusCause { #[error("Worker address related environment variable is missing")] PreComputeWorkerAddressMissing, } + +impl serde::Serialize for ReplicateStatusCause { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("ReplicateStatusCause", 2)?; + state.serialize_field("cause", &ReplicateStatusCauseDiscriminants::from(self))?; + state.serialize_field("message", &self.to_string())?; + state.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::to_string; + + const DATASET_FILENAME: &str = "0xDatasetAddress"; + + #[test] + fn serialize_produces_correct_json_when_error_has_dataset_filename() { + let cause = ReplicateStatusCause::PreComputeDatasetUrlMissing(DATASET_FILENAME.to_string()); + let serialized = to_string(&cause).unwrap(); + assert_eq!( + serialized, + r#"{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAddress"}"# + ); + } + + #[test] + fn serialize_produces_correct_json_when_error_has_no_index() { + let cause = ReplicateStatusCause::PreComputeInvalidTeeSignature; + let serialized = to_string(&cause).unwrap(); + assert_eq!( + serialized, + r#"{"cause":"PRE_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"}"# + ); + } + + #[test] + fn serialize_produces_correct_json_when_multiple_dataset_errors_with_filenames() { + let test_cases = vec![ + ( + ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing(1), + r#"{"cause":"PRE_COMPUTE_AT_LEAST_ONE_INPUT_FILE_URL_MISSING","message":"input file URL 1 is missing"}"#, + ), + ( + ReplicateStatusCause::PreComputeDatasetChecksumMissing( + DATASET_FILENAME.to_string(), + ), + r#"{"cause":"PRE_COMPUTE_DATASET_CHECKSUM_MISSING","message":"Dataset checksum related environment variable is missing for dataset 0xDatasetAddress"}"#, + ), + ( + ReplicateStatusCause::PreComputeDatasetDecryptionFailed( + DATASET_FILENAME.to_string(), + ), + r#"{"cause":"PRE_COMPUTE_DATASET_DECRYPTION_FAILED","message":"Failed to decrypt dataset 0xDatasetAddress"}"#, + ), + ( + ReplicateStatusCause::PreComputeDatasetDownloadFailed(DATASET_FILENAME.to_string()), + r#"{"cause":"PRE_COMPUTE_DATASET_DOWNLOAD_FAILED","message":"Failed to download encrypted dataset file for dataset 0xDatasetAddress"}"#, + ), + ( + ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + DATASET_FILENAME.to_string(), + ), + r#"{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xDatasetAddress"}"#, + ), + ]; + + for (cause, expected) in test_cases { + let serialized = to_string(&cause).unwrap(); + assert_eq!(serialized, expected); + } + } + + #[test] + fn serialize_produces_correct_json_when_vector_of_multiple_errors() { + let causes = vec![ + ReplicateStatusCause::PreComputeDatasetUrlMissing(DATASET_FILENAME.to_string()), + ReplicateStatusCause::PreComputeInvalidDatasetChecksum("0xAnotherDataset".to_string()), + ]; + + let serialized = to_string(&causes).unwrap(); + let expected = r#"[{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAddress"},{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xAnotherDataset"}]"#; + assert_eq!(serialized, expected); + } +} diff --git a/pre-compute/src/compute/pre_compute_app.rs b/pre-compute/src/compute/pre_compute_app.rs index ca12b10..825f3ef 100644 --- a/pre-compute/src/compute/pre_compute_app.rs +++ b/pre-compute/src/compute/pre_compute_app.rs @@ -59,7 +59,7 @@ impl PreComputeAppTrait for PreComputeApp { // TODO: Collect all errors instead of propagating immediately, and return the list of errors self.pre_compute_args = PreComputeArgs::read_args()?; self.check_output_folder()?; - for dataset in &self.pre_compute_args.datasets { + for dataset in self.pre_compute_args.datasets.iter() { let encrypted_content = dataset.download_encrypted_dataset(&self.chain_task_id)?; let plain_content = dataset.decrypt_dataset(&encrypted_content)?; self.save_plain_dataset_file(&plain_content, &dataset.filename)?; diff --git a/pre-compute/src/compute/pre_compute_args.rs b/pre-compute/src/compute/pre_compute_args.rs index 1bd074c..e230afb 100644 --- a/pre-compute/src/compute/pre_compute_args.rs +++ b/pre-compute/src/compute/pre_compute_args.rs @@ -86,21 +86,21 @@ impl PreComputeArgs { // Read datasets let start_index = if is_dataset_required { 0 } else { 1 }; for i in start_index..=iexec_bulk_slice_size { + let filename = get_env_var_or_error( + TeeSessionEnvironmentVariable::IexecDatasetFilename(i), + ReplicateStatusCause::PreComputeDatasetFilenameMissing(format!("dataset_{i}")), + )?; let url = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecDatasetUrl(i), - ReplicateStatusCause::PreComputeDatasetUrlMissing, // TODO: replace with a more specific error for bulk dataset + ReplicateStatusCause::PreComputeDatasetUrlMissing(filename.clone()), )?; let checksum = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecDatasetChecksum(i), - ReplicateStatusCause::PreComputeDatasetChecksumMissing, // TODO: replace with a more specific error for bulk dataset - )?; - let filename = get_env_var_or_error( - TeeSessionEnvironmentVariable::IexecDatasetFilename(i), - ReplicateStatusCause::PreComputeDatasetFilenameMissing, // TODO: replace with a more specific error for bulk dataset + ReplicateStatusCause::PreComputeDatasetChecksumMissing(filename.clone()), )?; let key = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecDatasetKey(i), - ReplicateStatusCause::PreComputeDatasetKeyMissing, // TODO: replace with a more specific error for bulk dataset + ReplicateStatusCause::PreComputeDatasetKeyMissing(filename.clone()), )?; datasets.push(Dataset::new(url, checksum, filename, key)); @@ -118,7 +118,7 @@ impl PreComputeArgs { for i in 1..=input_files_nb { let url = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecInputFileUrlPrefix(i), - ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing, + ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing(i), )?; input_files.push(url); } @@ -427,7 +427,7 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetUrlMissing + ReplicateStatusCause::PreComputeDatasetUrlMissing("bulk-dataset-1.txt".to_string()) ); }); } @@ -446,7 +446,9 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetChecksumMissing + ReplicateStatusCause::PreComputeDatasetChecksumMissing( + "bulk-dataset-2.txt".to_string() + ) ); }); } @@ -465,7 +467,7 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetFilenameMissing + ReplicateStatusCause::PreComputeDatasetFilenameMissing("dataset_2".to_string()) ); }); } @@ -484,7 +486,7 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetKeyMissing + ReplicateStatusCause::PreComputeDatasetKeyMissing("bulk-dataset-1.txt".to_string()) ); }); } @@ -508,23 +510,25 @@ mod tests { ), ( IexecDatasetUrl(0), - ReplicateStatusCause::PreComputeDatasetUrlMissing, + ReplicateStatusCause::PreComputeDatasetUrlMissing(DATASET_FILENAME.to_string()), ), ( IexecDatasetKey(0), - ReplicateStatusCause::PreComputeDatasetKeyMissing, + ReplicateStatusCause::PreComputeDatasetKeyMissing(DATASET_FILENAME.to_string()), ), ( IexecDatasetChecksum(0), - ReplicateStatusCause::PreComputeDatasetChecksumMissing, + ReplicateStatusCause::PreComputeDatasetChecksumMissing( + DATASET_FILENAME.to_string(), + ), ), ( IexecDatasetFilename(0), - ReplicateStatusCause::PreComputeDatasetFilenameMissing, + ReplicateStatusCause::PreComputeDatasetFilenameMissing("dataset_0".to_string()), ), ( IexecInputFileUrlPrefix(1), - ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing, + ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing(1), ), ]; for (env_var, error) in missing_env_var_causes { diff --git a/pre-compute/src/compute/utils/env_utils.rs b/pre-compute/src/compute/utils/env_utils.rs index 270f0d6..72598d5 100644 --- a/pre-compute/src/compute/utils/env_utils.rs +++ b/pre-compute/src/compute/utils/env_utils.rs @@ -71,6 +71,8 @@ mod tests { use super::*; use temp_env; + const DATASET_ADDRESS: &str = "0xDatasetAddress"; + #[test] fn name_succeeds_when_simple_environment_variable_names() { assert_eq!( @@ -202,7 +204,8 @@ mod tests { #[test] fn get_env_var_or_error_succeeds_when_indexed_variables() { let env_var = TeeSessionEnvironmentVariable::IexecDatasetChecksum(1); - let status_cause = ReplicateStatusCause::PreComputeDatasetChecksumMissing; + let status_cause = + ReplicateStatusCause::PreComputeDatasetChecksumMissing(DATASET_ADDRESS.to_string()); temp_env::with_var("IEXEC_DATASET_1_CHECKSUM", Some("abc123def456"), || { let result = get_env_var_or_error(env_var, status_cause.clone());