diff --git a/Cargo.lock b/Cargo.lock
index fd61e06..c3f7c26 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4128,7 +4128,7 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
 
 [[package]]
 name = "tee-worker-post-compute"
-version = "0.1.0"
+version = "0.2.0"
 dependencies = [
  "aes",
  "alloy-signer",
@@ -4159,7 +4159,7 @@ dependencies = [
 
 [[package]]
 name = "tee-worker-pre-compute"
-version = "0.1.0"
+version = "0.2.0"
 dependencies = [
  "aes",
  "alloy-signer",
diff --git a/pre-compute/src/compute.rs b/pre-compute/src/compute.rs
index ed1c48f..a39f08f 100644
--- a/pre-compute/src/compute.rs
+++ b/pre-compute/src/compute.rs
@@ -1,4 +1,5 @@
 pub mod app_runner;
+pub mod dataset;
 pub mod errors;
 pub mod pre_compute_app;
 pub mod pre_compute_args;
diff --git a/pre-compute/src/compute/dataset.rs b/pre-compute/src/compute/dataset.rs
new file mode 100644
index 0000000..33003d0
--- /dev/null
+++ b/pre-compute/src/compute/dataset.rs
@@ -0,0 +1,233 @@
+use crate::compute::errors::ReplicateStatusCause;
+use crate::compute::utils::file_utils::download_from_url;
+use crate::compute::utils::hash_utils::sha256_from_bytes;
+use aes::Aes256;
+use base64::{Engine as _, engine::general_purpose};
+use cbc::{
+    Decryptor,
+    cipher::{BlockDecryptMut, KeyIvInit, block_padding::Pkcs7},
+};
+use log::{error, info};
+use multiaddr::Multiaddr;
+use std::str::FromStr;
+
+type Aes256CbcDec = Decryptor<Aes256>;
+const IPFS_GATEWAYS: &[&str] = &[
+    "https://ipfs-gateway.v8-bellecour.iex.ec",
+    "https://gateway.ipfs.io",
+    "https://gateway.pinata.cloud",
+];
+const AES_KEY_LENGTH: usize = 32;
+const AES_IV_LENGTH: usize = 16;
+
+/// Represents a dataset in a Trusted Execution Environment (TEE).
+///
+/// This structure contains all the information needed to download, verify, and decrypt
+/// a single dataset.
+#[cfg_attr(test, derive(Debug))]
+#[derive(Clone, Default)]
+pub struct Dataset {
+    pub url: String,
+    pub checksum: String,
+    pub filename: String,
+    pub key: String,
+}
+
+impl Dataset {
+    pub fn new(url: String, checksum: String, filename: String, key: String) -> Self {
+        Dataset {
+            url,
+            checksum,
+            filename,
+            key,
+        }
+    }
+
+    /// Downloads the encrypted dataset file from a URL or IPFS multi-address, and verifies its checksum.
+    ///
+    /// # Arguments
+    ///
+    /// * `chain_task_id` - The chain task ID for logging
+    ///
+    /// # Returns
+    ///
+    /// * `Ok(Vec<u8>)` containing the dataset's encrypted content if download and verification succeed.
+    /// * `Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed)` if the download fails.
+    /// * `Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum)` if checksum validation fails.
+    pub fn download_encrypted_dataset(
+        &self,
+        chain_task_id: &str,
+    ) -> Result<Vec<u8>, ReplicateStatusCause> {
+        info!(
+            "Downloading encrypted dataset file [chainTaskId:{chain_task_id}, url:{}]",
+            self.url
+        );
+
+        let encrypted_content = if is_multi_address(&self.url) {
+            IPFS_GATEWAYS.iter().find_map(|gateway| {
+                let full_url = format!("{gateway}{}", self.url);
+                info!("Attempting to download dataset from {full_url}");
+
+                if let Some(content) = download_from_url(&full_url) {
+                    info!("Successfully downloaded from {full_url}");
+                    Some(content)
+                } else {
+                    error!("Failed to download from {full_url}");
+                    None
+                }
+            })
+        } else {
+            download_from_url(&self.url)
+        }
+        .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed)?;
+
+        info!("Checking encrypted dataset checksum [chainTaskId:{chain_task_id}]");
+        let actual_checksum = sha256_from_bytes(&encrypted_content);
+
+        if actual_checksum != self.checksum {
+            error!(
+                "Invalid dataset checksum [chainTaskId:{chain_task_id}, expected:{}, actual:{actual_checksum}]",
+                self.checksum
+            );
+            return Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum);
+        }
+
+        info!("Dataset downloaded and verified successfully.");
+        Ok(encrypted_content)
+    }
+
+    /// Decrypts the provided encrypted dataset bytes using AES-CBC.
+    ///
+    /// The first 16 bytes of `encrypted_content` are treated as the IV.
+    /// The rest is the ciphertext. The decryption key is decoded from a Base64 string.
+    ///
+    /// # Arguments
+    ///
+    /// * `encrypted_content` - Full encrypted dataset, including the IV prefix.
+    ///
+    /// # Returns
+    ///
+    /// * `Ok(Vec<u8>)` containing the plaintext dataset if decryption succeeds.
+    /// * `Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed)` if the key is missing, decoding fails, or decryption fails.
+    pub fn decrypt_dataset(
+        &self,
+        encrypted_content: &[u8],
+    ) -> Result<Vec<u8>, ReplicateStatusCause> {
+        let key = general_purpose::STANDARD
+            .decode(&self.key)
+            .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed)?;
+
+        if encrypted_content.len() < AES_IV_LENGTH || key.len() != AES_KEY_LENGTH {
+            return Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed);
+        }
+
+        let key_slice = &key[..AES_KEY_LENGTH];
+        let iv_slice = &encrypted_content[..AES_IV_LENGTH];
+        let ciphertext = &encrypted_content[AES_IV_LENGTH..];
+
+        Aes256CbcDec::new(key_slice.into(), iv_slice.into())
+            .decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
+            .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed)
+    }
+}
+
+fn is_multi_address(uri: &str) -> bool {
+    !uri.trim().is_empty() && Multiaddr::from_str(uri).is_ok()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    const CHAIN_TASK_ID: &str = "0x123456789abcdef";
+    const DATASET_CHECKSUM: &str =
+        "0x02a12ef127dcfbdb294a090c8f0b69a0ca30b7940fc36cabf971f488efd374d7";
+    const ENCRYPTED_DATASET_KEY: &str = "ubA6H9emVPJT91/flYAmnKHC0phSV3cfuqsLxQfgow0=";
+    const HTTP_DATASET_URL: &str = "https://raw.githubusercontent.com/iExecBlockchainComputing/tee-worker-pre-compute-rust/main/src/tests_resources/encrypted-data.bin";
+    const PLAIN_DATA_FILE: &str = "plain-data.txt";
+    const IPFS_DATASET_URL: &str = "/ipfs/QmUVhChbLFiuzNK1g2GsWyWEiad7SXPqARnWzGumgziwEp";
+
+    fn get_test_dataset() -> Dataset {
+        Dataset::new(
+            HTTP_DATASET_URL.to_string(),
+            DATASET_CHECKSUM.to_string(),
+            PLAIN_DATA_FILE.to_string(),
+            ENCRYPTED_DATASET_KEY.to_string(),
+        )
+    }
+
+    // region download_encrypted_dataset
+    #[test]
+    fn download_encrypted_dataset_success() {
+        let dataset = get_test_dataset();
+        let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID);
+        assert!(actual_content.is_ok());
+    }
+
+    #[test]
+    fn download_encrypted_dataset_failure_with_invalid_dataset_url() {
+        let mut dataset = get_test_dataset();
+        dataset.url = "http://bad-url".to_string();
+        let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID);
+        assert_eq!(
+            actual_content,
+            Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed)
+        );
+    }
+
+    #[test]
+    fn download_encrypted_dataset_success_with_valid_iexec_gateway() {
+        let mut dataset = get_test_dataset();
+        dataset.url = IPFS_DATASET_URL.to_string();
+        dataset.checksum =
+            "0x323b1637c7999942fbebfe5d42fe15dbfe93737577663afa0181938d7ad4a2ac".to_string();
+        let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID);
+        let expected_content = Ok("hello world !\n".as_bytes().to_vec());
+        assert_eq!(actual_content, expected_content);
+    }
+
+    #[test]
+    fn download_encrypted_dataset_failure_with_invalid_gateway() {
+        let mut dataset = get_test_dataset();
+        dataset.url = "/ipfs/INVALID_IPFS_DATASET_URL".to_string();
+        let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID);
+        let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed);
+        assert_eq!(actual_content, expected_content);
+    }
+
+    #[test]
+    fn download_encrypted_dataset_failure_with_invalid_dataset_checksum() {
+        let mut dataset = get_test_dataset();
+        dataset.checksum = "invalid_dataset_checksum".to_string();
+        let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID);
+        let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum);
+        assert_eq!(actual_content, expected_content);
+    }
+    // endregion
+
+    // region decrypt_dataset
+    #[test]
+    fn decrypt_dataset_success_with_valid_dataset() {
+        let dataset = get_test_dataset();
+
+        let encrypted_data = dataset.download_encrypted_dataset(CHAIN_TASK_ID).unwrap();
+        let expected_plain_data = Ok("Some very useful data.".as_bytes().to_vec());
+        let actual_plain_data = dataset.decrypt_dataset(&encrypted_data);
+
+        assert_eq!(actual_plain_data, expected_plain_data);
+    }
+
+    #[test]
+    fn decrypt_dataset_failure_with_bad_key() {
+        let mut dataset = get_test_dataset();
+        dataset.key = "bad_key".to_string();
+        let encrypted_data = dataset.download_encrypted_dataset(CHAIN_TASK_ID).unwrap();
+        let actual_plain_data = dataset.decrypt_dataset(&encrypted_data);
+
+        assert_eq!(
+            actual_plain_data,
+            Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed)
+        );
+    }
+    // endregion
+}
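
A round-trip sketch of the `[IV || ciphertext]` layout `decrypt_dataset` expects, written as an extra unit test that could sit in the tests module above. It assumes `cbc::Encryptor`, the encryption counterpart of the `Decryptor` this module already uses; the key, IV, and payload are illustrative:

    // Sketch: encrypt a sample payload, prepend the IV, and round-trip it
    // through Dataset::decrypt_dataset (key/IV/payload are made up).
    #[test]
    fn decrypt_dataset_round_trips_iv_prefixed_payload() {
        use base64::Engine;
        use cbc::cipher::BlockEncryptMut;
        type Aes256CbcEnc = cbc::Encryptor<Aes256>;

        let key = [0x42u8; 32]; // illustrative 32-byte key
        let iv = [0x24u8; 16]; // illustrative 16-byte IV

        let ciphertext = Aes256CbcEnc::new(&key.into(), &iv.into())
            .encrypt_padded_vec_mut::<Pkcs7>(b"Some very useful data.");

        // decrypt_dataset treats the first 16 bytes as the IV.
        let mut encrypted_content = iv.to_vec();
        encrypted_content.extend_from_slice(&ciphertext);

        let mut dataset = get_test_dataset();
        dataset.key = general_purpose::STANDARD.encode(key); // key travels Base64-encoded

        let plain = dataset.decrypt_dataset(&encrypted_content).unwrap();
        assert_eq!(plain, b"Some very useful data.");
    }

Prepending the IV keeps each encrypted payload self-describing, so the TEE session only has to provision the Base64 key per dataset.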
diff --git a/pre-compute/src/compute/pre_compute_app.rs b/pre-compute/src/compute/pre_compute_app.rs
index 48a82a7..829ded0 100644
--- a/pre-compute/src/compute/pre_compute_app.rs
+++ b/pre-compute/src/compute/pre_compute_app.rs
@@ -1,37 +1,22 @@
 use crate::compute::errors::ReplicateStatusCause;
 use crate::compute::pre_compute_args::PreComputeArgs;
-use crate::compute::utils::file_utils::{download_file, download_from_url, write_file};
-use crate::compute::utils::hash_utils::{sha256, sha256_from_bytes};
-use aes::Aes256;
-use base64::{Engine as _, engine::general_purpose};
-use cbc::{
-    Decryptor,
-    cipher::{BlockDecryptMut, KeyIvInit, block_padding::Pkcs7},
-};
+use crate::compute::utils::file_utils::{download_file, write_file};
+use crate::compute::utils::hash_utils::sha256;
 use log::{error, info};
 #[cfg(test)]
 use mockall::automock;
-use multiaddr::Multiaddr;
 use std::path::{Path, PathBuf};
-use std::str::FromStr;
-
-type Aes256CbcDec = Decryptor<Aes256>;
-const IPFS_GATEWAYS: &[&str] = &[
-    "https://ipfs-gateway.v8-bellecour.iex.ec",
-    "https://gateway.ipfs.io",
-    "https://gateway.pinata.cloud",
-];
-const AES_KEY_LENGTH: usize = 32;
-const AES_IV_LENGTH: usize = 16;
 
 #[cfg_attr(test, automock)]
 pub trait PreComputeAppTrait {
     fn run(&mut self) -> Result<(), ReplicateStatusCause>;
     fn check_output_folder(&self) -> Result<(), ReplicateStatusCause>;
     fn download_input_files(&self) -> Result<(), ReplicateStatusCause>;
-    fn download_encrypted_dataset(&self) -> Result<Vec<u8>, ReplicateStatusCause>;
-    fn decrypt_dataset(&self, encrypted_content: &[u8]) -> Result<Vec<u8>, ReplicateStatusCause>;
-    fn save_plain_dataset_file(&self, plain_content: &[u8]) -> Result<(), ReplicateStatusCause>;
+    fn save_plain_dataset_file(
+        &self,
+        plain_content: &[u8],
+        plain_dataset_filename: &str,
+    ) -> Result<(), ReplicateStatusCause>;
 }
 
 pub struct PreComputeApp {
@@ -71,12 +56,13 @@ impl PreComputeAppTrait for PreComputeApp {
     ///     app.run();
     /// ```
     fn run(&mut self) -> Result<(), ReplicateStatusCause> {
+        // TODO: Collect all errors instead of propagating immediately, and return the list of errors
         self.pre_compute_args = PreComputeArgs::read_args()?;
         self.check_output_folder()?;
-        if self.pre_compute_args.is_dataset_required {
-            let encrypted_content = self.download_encrypted_dataset()?;
-            let plain_content = self.decrypt_dataset(&encrypted_content)?;
-            self.save_plain_dataset_file(&plain_content)?;
+        for dataset in &self.pre_compute_args.datasets {
+            let encrypted_content = dataset.download_encrypted_dataset(&self.chain_task_id)?;
+            let plain_content = dataset.decrypt_dataset(&encrypted_content)?;
+            self.save_plain_dataset_file(&plain_content, &dataset.filename)?;
         }
         self.download_input_files()?;
         Ok(())
@@ -134,88 +120,6 @@ impl PreComputeAppTrait for PreComputeApp {
         Ok(())
     }
 
-    /// Downloads the encrypted dataset file from a URL or IPFS multi-address, and verifies its checksum.
-    ///
-    /// # Returns
-    ///
-    /// * `Ok(Vec<u8>)` containing the dataset's encrypted content if download and verification succeed.
-    /// * `Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed)` if the download fails or inputs are missing.
-    /// * `Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum)` if checksum validation fails.
-    fn download_encrypted_dataset(&self) -> Result<Vec<u8>, ReplicateStatusCause> {
-        let args = &self.pre_compute_args;
-        let chain_task_id = &self.chain_task_id;
-        let encrypted_dataset_url: &str = &args.encrypted_dataset_url;
-
-        info!(
-            "Downloading encrypted dataset file [chainTaskId:{chain_task_id}, url:{encrypted_dataset_url}]",
-        );
-
-        let encrypted_content = if is_multi_address(encrypted_dataset_url) {
-            IPFS_GATEWAYS.iter().find_map(|gateway| {
-                let full_url = format!("{gateway}{encrypted_dataset_url}");
-                info!("Attempting to download dataset from {full_url}");
-
-                if let Some(content) = download_from_url(&full_url) {
-                    info!("Successfully downloaded from {full_url}");
-                    Some(content)
-                } else {
-                    info!("Failed to download from {full_url}");
-                    None
-                }
-            })
-        } else {
-            download_from_url(encrypted_dataset_url)
-        }
-        .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed)?;
-
-        info!("Checking encrypted dataset checksum [chainTaskId:{chain_task_id}]");
-        let expected_checksum: &str = &args.encrypted_dataset_checksum;
-        let actual_checksum = sha256_from_bytes(&encrypted_content);
-
-        if actual_checksum != expected_checksum {
-            error!(
-                "Invalid dataset checksum [chainTaskId:{chain_task_id}, expected:{expected_checksum}, actual:{actual_checksum}]"
-            );
-            return Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum);
-        }
-
-        info!("Dataset downloaded and verified successfully.");
-        Ok(encrypted_content)
-    }
-
-    /// Decrypts the provided encrypted dataset bytes using AES-CBC.
-    ///
-    /// The first 16 bytes of `encrypted_content` are treated as the IV.
-    /// The rest is the ciphertext. The decryption key is decoded from a Base64 string.
-    ///
-    /// # Arguments
-    ///
-    /// * `encrypted_content` - Full encrypted dataset, including the IV prefix.
-    ///
-    /// # Returns
-    ///
-    /// * `Ok(Vec<u8>)` containing the plaintext dataset if decryption succeeds.
-    /// * `Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed)` if the key is missing, decoding fails, or decryption fails.
-    fn decrypt_dataset(&self, encrypted_content: &[u8]) -> Result<Vec<u8>, ReplicateStatusCause> {
-        let base64_key: &str = &self.pre_compute_args.encrypted_dataset_base64_key;
-
-        let key = general_purpose::STANDARD
-            .decode(base64_key)
-            .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed)?;
-
-        if encrypted_content.len() < AES_IV_LENGTH || key.len() != AES_KEY_LENGTH {
-            return Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed);
-        }
-
-        let key_slice = &key[..AES_KEY_LENGTH];
-        let iv_slice = &encrypted_content[..AES_IV_LENGTH];
-        let ciphertext = &encrypted_content[AES_IV_LENGTH..];
-
-        Aes256CbcDec::new(key_slice.into(), iv_slice.into())
-            .decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
-            .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed)
-    }
-
     /// Saves the decrypted (plain) dataset to disk in the configured output directory.
     ///
-    /// The output filename is taken from `pre_compute_args.plain_dataset_filename`.
+    /// The output filename is taken from the `plain_dataset_filename` argument.
@@ -228,11 +132,14 @@ impl PreComputeAppTrait for PreComputeApp {
     ///
     /// * `Ok(())` if the file is successfully saved.
     /// * `Err(ReplicateStatusCause::PreComputeSavingPlainDatasetFailed)` if the path is invalid or write fails.
-    fn save_plain_dataset_file(&self, plain_dataset: &[u8]) -> Result<(), ReplicateStatusCause> {
+    fn save_plain_dataset_file(
+        &self,
+        plain_dataset: &[u8],
+        plain_dataset_filename: &str,
+    ) -> Result<(), ReplicateStatusCause> {
         let chain_task_id: &str = &self.chain_task_id;
         let args = &self.pre_compute_args;
         let output_dir: &str = &args.output_dir;
-        let plain_dataset_filename: &str = &args.plain_dataset_filename;
 
         let mut path = PathBuf::from(output_dir);
         path.push(plain_dataset_filename);
@@ -251,13 +158,10 @@ impl PreComputeAppTrait for PreComputeApp {
     }
 }
 
-fn is_multi_address(uri: &str) -> bool {
-    !uri.trim().is_empty() && Multiaddr::from_str(uri).is_ok()
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::compute::dataset::Dataset;
     use crate::compute::pre_compute_args::PreComputeArgs;
     use std::fs;
     use tempfile::TempDir;
@@ -270,7 +174,6 @@ mod tests {
         "0x02a12ef127dcfbdb294a090c8f0b69a0ca30b7940fc36cabf971f488efd374d7";
     const ENCRYPTED_DATASET_KEY: &str = "ubA6H9emVPJT91/flYAmnKHC0phSV3cfuqsLxQfgow0=";
     const HTTP_DATASET_URL: &str = "https://raw.githubusercontent.com/iExecBlockchainComputing/tee-worker-pre-compute-rust/main/src/tests_resources/encrypted-data.bin";
-    const IPFS_DATASET_URL: &str = "/ipfs/QmUVhChbLFiuzNK1g2GsWyWEiad7SXPqARnWzGumgziwEp";
     const PLAIN_DATA_FILE: &str = "plain-data.txt";
 
     fn get_pre_compute_app(
@@ -284,10 +187,13 @@ mod tests {
                 input_files: urls.into_iter().map(String::from).collect(),
                 output_dir: output_dir.to_string(),
                 is_dataset_required: true,
-                encrypted_dataset_url: HTTP_DATASET_URL.to_string(),
-                encrypted_dataset_base64_key: ENCRYPTED_DATASET_KEY.to_string(),
-                encrypted_dataset_checksum: DATASET_CHECKSUM.to_string(),
-                plain_dataset_filename: PLAIN_DATA_FILE.to_string(),
+                bulk_size: 0,
+                datasets: vec![Dataset {
+                    url: HTTP_DATASET_URL.to_string(),
+                    checksum: DATASET_CHECKSUM.to_string(),
+                    filename: PLAIN_DATA_FILE.to_string(),
+                    key: ENCRYPTED_DATASET_KEY.to_string(),
+                }],
             },
         }
     }
@@ -422,84 +328,6 @@ mod tests {
     }
     // endregion
 
-    // region download_encrypted_dataset
-    #[test]
-    fn download_encrypted_dataset_success_with_valid_dataset_url() {
-        let app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
-
-        let actual_content = app.download_encrypted_dataset();
-        let expected_content = download_from_url(HTTP_DATASET_URL)
-            .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed);
-        assert_eq!(actual_content, expected_content);
-    }
-
-    #[test]
-    fn download_encrypted_dataset_failure_with_invalid_dataset_url() {
-        let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
-        app.pre_compute_args.encrypted_dataset_url = "http://bad-url".to_string();
-        let actual_content = app.download_encrypted_dataset();
-        assert_eq!(
-            actual_content,
-            Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed)
-        );
-    }
-
-    #[test]
-    fn download_encrypted_dataset_success_with_valid_iexec_gateway() {
-        let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
-        app.pre_compute_args.encrypted_dataset_url = IPFS_DATASET_URL.to_string();
-        app.pre_compute_args.encrypted_dataset_checksum =
-            "0x323b1637c7999942fbebfe5d42fe15dbfe93737577663afa0181938d7ad4a2ac".to_string();
-        let actual_content = app.download_encrypted_dataset();
-        let expected_content = Ok("hello world !\n".as_bytes().to_vec());
-        assert_eq!(actual_content, expected_content);
-    }
expected_content); - } - - #[test] - fn download_encrypted_dataset_failure_with_invalid_gateway() { - let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], ""); - app.pre_compute_args.encrypted_dataset_url = "/ipfs/INVALID_IPFS_DATASET_URL".to_string(); - let actual_content = app.download_encrypted_dataset(); - let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed); - assert_eq!(actual_content, expected_content); - } - - #[test] - fn download_encrypted_dataset_failure_with_invalid_dataset_checksum() { - let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], ""); - app.pre_compute_args.encrypted_dataset_checksum = "invalid_dataset_checksum".to_string(); - let actual_content = app.download_encrypted_dataset(); - let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum); - assert_eq!(actual_content, expected_content); - } - // endregion - - // region decrypt_dataset - #[test] - fn decrypt_dataset_success_with_valid_dataset() { - let app = get_pre_compute_app(CHAIN_TASK_ID, vec![], ""); - - let encrypted_data = app.download_encrypted_dataset().unwrap(); - let expected_plain_data = Ok("Some very useful data.".as_bytes().to_vec()); - let actual_plain_data = app.decrypt_dataset(&encrypted_data); - - assert_eq!(actual_plain_data, expected_plain_data); - } - - #[test] - fn decrypt_dataset_failure_with_bad_key() { - let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], ""); - app.pre_compute_args.encrypted_dataset_base64_key = "bad_key".to_string(); - let encrypted_data = app.download_encrypted_dataset().unwrap(); - let actual_plain_data = app.decrypt_dataset(&encrypted_data); - - assert_eq!( - actual_plain_data, - Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed) - ); - } - // endregion - // region save_plain_dataset_file #[test] fn save_plain_dataset_file_success_with_valid_output_dir() { @@ -509,7 +337,7 @@ mod tests { let app = get_pre_compute_app(CHAIN_TASK_ID, vec![], output_path); let plain_dataset = "Some very useful data.".as_bytes().to_vec(); - let saved_dataset = app.save_plain_dataset_file(&plain_dataset); + let saved_dataset = app.save_plain_dataset_file(&plain_dataset, PLAIN_DATA_FILE); assert!(saved_dataset.is_ok()); @@ -532,10 +360,10 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let output_path = temp_dir.path().to_str().unwrap(); - let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], output_path); - app.pre_compute_args.plain_dataset_filename = "/some-folder-123/not-found".to_string(); + let app = get_pre_compute_app(CHAIN_TASK_ID, vec![], output_path); let plain_dataset = "Some very useful data.".as_bytes().to_vec(); - let saved_dataset = app.save_plain_dataset_file(&plain_dataset); + let saved_dataset = + app.save_plain_dataset_file(&plain_dataset, "/some-folder-123/not-found"); assert_eq!( saved_dataset, diff --git a/pre-compute/src/compute/pre_compute_args.rs b/pre-compute/src/compute/pre_compute_args.rs index 902df88..781d893 100644 --- a/pre-compute/src/compute/pre_compute_args.rs +++ b/pre-compute/src/compute/pre_compute_args.rs @@ -1,3 +1,4 @@ +use crate::compute::dataset::Dataset; use crate::compute::errors::ReplicateStatusCause; use crate::compute::utils::env_utils::{TeeSessionEnvironmentVariable, get_env_var_or_error}; @@ -11,12 +12,11 @@ pub struct PreComputeArgs { pub output_dir: String, // Dataset related fields pub is_dataset_required: bool, - pub encrypted_dataset_url: String, - pub encrypted_dataset_base64_key: String, - pub encrypted_dataset_checksum: String, - pub 
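
The TODO in `run()` above calls for collecting errors across datasets instead of failing fast. One possible shape for that loop, sketched against the new `Dataset` API (the aggregation strategy and any list-carrying error variant are hypothetical, not part of this diff):

    // Hypothetical error-collecting variant of the dataset loop in run().
    let mut failures: Vec<ReplicateStatusCause> = Vec::new();
    for dataset in &self.pre_compute_args.datasets {
        let result = dataset
            .download_encrypted_dataset(&self.chain_task_id)
            .and_then(|encrypted| dataset.decrypt_dataset(&encrypted))
            .and_then(|plain| self.save_plain_dataset_file(&plain, &dataset.filename));
        if let Err(cause) = result {
            error!("Dataset processing failed [filename:{}]", dataset.filename);
            failures.push(cause);
        }
    }
    if let Some(first) = failures.into_iter().next() {
        return Err(first); // a future variant could carry the whole list instead
    }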
diff --git a/pre-compute/src/compute/pre_compute_args.rs b/pre-compute/src/compute/pre_compute_args.rs
index 902df88..781d893 100644
--- a/pre-compute/src/compute/pre_compute_args.rs
+++ b/pre-compute/src/compute/pre_compute_args.rs
@@ -1,3 +1,4 @@
+use crate::compute::dataset::Dataset;
 use crate::compute::errors::ReplicateStatusCause;
 use crate::compute::utils::env_utils::{TeeSessionEnvironmentVariable, get_env_var_or_error};
 
@@ -11,12 +12,11 @@ pub struct PreComputeArgs {
     pub output_dir: String,
     // Dataset related fields
     pub is_dataset_required: bool,
-    pub encrypted_dataset_url: String,
-    pub encrypted_dataset_base64_key: String,
-    pub encrypted_dataset_checksum: String,
-    pub plain_dataset_filename: String,
     // Input files
     pub input_files: Vec<String>,
+    // Bulk processing
+    pub bulk_size: usize,
+    pub datasets: Vec<Dataset>,
 }
 
 impl PreComputeArgs {
@@ -28,20 +28,27 @@ impl PreComputeArgs {
     /// - `IEXEC_PRE_COMPUTE_OUT`: Output directory path
     /// - `IEXEC_DATASET_REQUIRED`: Boolean ("true"/"false") indicating dataset requirement
     /// - `IEXEC_INPUT_FILES_NUMBER`: Number of input files to load
+    /// - `BULK_SIZE`: Number of bulk datasets (0 means no bulk processing)
     /// - Required when `IEXEC_DATASET_REQUIRED` = "true":
     ///   - `IEXEC_DATASET_URL`: Encrypted dataset URL
     ///   - `IEXEC_DATASET_KEY`: Base64-encoded dataset encryption key
     ///   - `IEXEC_DATASET_CHECKSUM`: Encrypted dataset checksum
     ///   - `IEXEC_DATASET_FILENAME`: Decrypted dataset filename
+    /// - Required when `BULK_SIZE` > 0 (for each dataset index from 1 to BULK_SIZE):
+    ///   - `BULK_DATASET_#_URL`: Dataset URL
+    ///   - `BULK_DATASET_#_CHECKSUM`: Dataset checksum
+    ///   - `BULK_DATASET_#_FILENAME`: Dataset filename
+    ///   - `BULK_DATASET_#_KEY`: Dataset decryption key
     /// - Input file URLs (`IEXEC_INPUT_FILE_URL_1`, `IEXEC_INPUT_FILE_URL_2`, etc.)
     ///
     /// # Errors
     /// Returns `ReplicateStatusCause` error variants for:
     /// - Missing required environment variables
     /// - Invalid boolean values in `IEXEC_DATASET_REQUIRED`
-    /// - Invalid numeric format in `IEXEC_INPUT_FILES_NUMBER`
+    /// - Invalid numeric format in `IEXEC_INPUT_FILES_NUMBER` or `BULK_SIZE`
     /// - Missing dataset parameters when required
     /// - Missing input file URLs
+    /// - Missing bulk dataset parameters when bulk processing is enabled
     ///
     /// # Example
     ///
@@ -66,28 +73,55 @@ impl PreComputeArgs {
             .parse::<bool>()
             .map_err(|_| ReplicateStatusCause::PreComputeIsDatasetRequiredMissing)?;
 
-        let mut encrypted_dataset_url = String::new();
-        let mut encrypted_dataset_base64_key = String::new();
-        let mut encrypted_dataset_checksum = String::new();
-        let mut plain_dataset_filename = String::new();
+        // Read bulk size (defaults to 0 if not present for backward compatibility)
+        let bulk_size_str = std::env::var(TeeSessionEnvironmentVariable::BulkSize.name())
+            .unwrap_or("0".to_string());
+        let bulk_size = bulk_size_str
+            .parse::<usize>()
+            .map_err(|_| ReplicateStatusCause::PreComputeIsDatasetRequiredMissing)?;
+
+        let mut datasets = Vec::with_capacity(bulk_size + 1);
         if is_dataset_required {
-            encrypted_dataset_url = get_env_var_or_error(
+            let url = get_env_var_or_error(
                 TeeSessionEnvironmentVariable::IexecDatasetUrl,
                 ReplicateStatusCause::PreComputeDatasetUrlMissing,
             )?;
-            encrypted_dataset_base64_key = get_env_var_or_error(
+            let key = get_env_var_or_error(
                 TeeSessionEnvironmentVariable::IexecDatasetKey,
                 ReplicateStatusCause::PreComputeDatasetKeyMissing,
            )?;
-            encrypted_dataset_checksum = get_env_var_or_error(
+            let checksum = get_env_var_or_error(
                 TeeSessionEnvironmentVariable::IexecDatasetChecksum,
                 ReplicateStatusCause::PreComputeDatasetChecksumMissing,
             )?;
-            plain_dataset_filename = get_env_var_or_error(
+            let filename = get_env_var_or_error(
                 TeeSessionEnvironmentVariable::IexecDatasetFilename,
                 ReplicateStatusCause::PreComputeDatasetFilenameMissing,
             )?;
+            datasets.push(Dataset::new(url, checksum, filename, key));
+        }
+
+        // Read bulk datasets
+        for i in 1..=bulk_size {
+            let url = get_env_var_or_error(
+                TeeSessionEnvironmentVariable::BulkDatasetUrl(i),
+                ReplicateStatusCause::PreComputeDatasetUrlMissing,
+            )?;
+            let checksum = get_env_var_or_error(
+                TeeSessionEnvironmentVariable::BulkDatasetChecksum(i),
+                ReplicateStatusCause::PreComputeDatasetChecksumMissing,
+            )?;
+            let filename = get_env_var_or_error(
+                TeeSessionEnvironmentVariable::BulkDatasetFilename(i),
+                ReplicateStatusCause::PreComputeDatasetFilenameMissing,
+            )?;
+            let key = get_env_var_or_error(
+                TeeSessionEnvironmentVariable::BulkDatasetKey(i),
+                ReplicateStatusCause::PreComputeDatasetKeyMissing,
+            )?;
+
+            datasets.push(Dataset::new(url, checksum, filename, key));
         }
 
         let input_files_nb_str = get_env_var_or_error(
@@ -110,11 +144,9 @@ impl PreComputeArgs {
         Ok(PreComputeArgs {
             output_dir,
             is_dataset_required,
-            encrypted_dataset_url,
-            encrypted_dataset_base64_key,
-            encrypted_dataset_checksum,
-            plain_dataset_filename,
             input_files,
+            bulk_size,
+            datasets,
         })
     }
 }
@@ -137,6 +169,7 @@ mod tests {
         vars.insert(IexecPreComputeOut.name(), OUTPUT_DIR.to_string());
         vars.insert(IsDatasetRequired.name(), "true".to_string());
         vars.insert(IexecInputFilesNumber.name(), "0".to_string());
+        vars.insert(BulkSize.name(), "0".to_string()); // Default to no bulk processing
         vars
     }
 
@@ -162,6 +195,26 @@ mod tests {
         vars
     }
 
+    fn setup_bulk_dataset_env_vars(count: usize) -> HashMap<String, String> {
+        let mut vars = HashMap::new();
+        vars.insert(BulkSize.name(), count.to_string());
+
+        for i in 1..=count {
+            vars.insert(
+                BulkDatasetUrl(i).name(),
+                format!("https://bulk-dataset-{i}.bin"),
+            );
+            vars.insert(BulkDatasetChecksum(i).name(), format!("0x{i}23checksum"));
+            vars.insert(
+                BulkDatasetFilename(i).name(),
+                format!("bulk-dataset-{i}.txt"),
+            );
+            vars.insert(BulkDatasetKey(i).name(), format!("bulkKey{i}23"));
+        }
+        vars
+    }
+
     fn to_temp_env_vars(map: HashMap<String, String>) -> Vec<(String, Option<String>)> {
         map.into_iter().map(|(k, v)| (k, Some(v))).collect()
     }
@@ -180,12 +233,10 @@ mod tests {
 
             assert_eq!(args.output_dir, OUTPUT_DIR);
             assert!(!args.is_dataset_required);
-            assert_eq!(args.encrypted_dataset_url, "");
-            assert_eq!(args.encrypted_dataset_base64_key, "");
-            assert_eq!(args.encrypted_dataset_checksum, "");
-            assert_eq!(args.plain_dataset_filename, "");
             assert_eq!(args.input_files.len(), 1);
             assert_eq!(args.input_files[0], "https://input-1.txt");
+            assert_eq!(args.bulk_size, 0);
+            assert_eq!(args.datasets.len(), 0);
         });
     }
 
@@ -204,14 +255,13 @@ mod tests {
 
             assert_eq!(args.output_dir, OUTPUT_DIR);
             assert!(args.is_dataset_required);
-            assert_eq!(args.encrypted_dataset_url, DATASET_URL.to_string());
-            assert_eq!(args.encrypted_dataset_base64_key, DATASET_KEY.to_string());
-            assert_eq!(
-                args.encrypted_dataset_checksum,
-                DATASET_CHECKSUM.to_string()
-            );
-            assert_eq!(args.plain_dataset_filename, DATASET_FILENAME.to_string());
+            assert_eq!(args.datasets[0].url, DATASET_URL.to_string());
+            assert_eq!(args.datasets[0].key, DATASET_KEY.to_string());
+            assert_eq!(args.datasets[0].checksum, DATASET_CHECKSUM.to_string());
+            assert_eq!(args.datasets[0].filename, DATASET_FILENAME.to_string());
             assert_eq!(args.input_files.len(), 0);
+            assert_eq!(args.bulk_size, 0);
+            assert_eq!(args.datasets.len(), 1);
         });
     }
 
@@ -231,14 +281,12 @@ mod tests {
 
             assert_eq!(args.output_dir, OUTPUT_DIR);
             assert!(!args.is_dataset_required);
-            assert_eq!(args.encrypted_dataset_url, "");
-            assert_eq!(args.encrypted_dataset_base64_key, "");
-            assert_eq!(args.encrypted_dataset_checksum, "");
-            assert_eq!(args.plain_dataset_filename, "");
             assert_eq!(args.input_files.len(), 3);
             assert_eq!(args.input_files[0], "https://input-1.txt");
             assert_eq!(args.input_files[1], "https://input-2.txt");
             assert_eq!(args.input_files[2], "https://input-3.txt");
+            assert_eq!(args.bulk_size, 0);
+            assert_eq!(args.datasets.len(), 0);
         });
     }
     // endregion
@@ -295,6 +343,171 @@ mod tests {
     }
     // endregion
 
+    // region bulk processing tests
+    #[test]
+    fn read_args_succeeds_with_bulk_datasets() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.insert(IsDatasetRequired.name(), "false".to_string());
+        env_vars.extend(setup_input_files_env_vars(0));
+        env_vars.extend(setup_bulk_dataset_env_vars(3));
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+
+            assert!(result.is_ok());
+            let args = result.unwrap();
+
+            assert_eq!(args.output_dir, OUTPUT_DIR);
+            assert!(!args.is_dataset_required);
+            assert_eq!(args.bulk_size, 3);
+            assert_eq!(args.datasets.len(), 3);
+            assert_eq!(args.input_files.len(), 0);
+
+            // Check first bulk dataset
+            assert_eq!(args.datasets[0].url, "https://bulk-dataset-1.bin");
+            assert_eq!(args.datasets[0].checksum, "0x123checksum");
+            assert_eq!(args.datasets[0].filename, "bulk-dataset-1.txt");
+            assert_eq!(args.datasets[0].key, "bulkKey123");
+
+            // Check second bulk dataset
+            assert_eq!(args.datasets[1].url, "https://bulk-dataset-2.bin");
+            assert_eq!(args.datasets[1].checksum, "0x223checksum");
+            assert_eq!(args.datasets[1].filename, "bulk-dataset-2.txt");
+            assert_eq!(args.datasets[1].key, "bulkKey223");
+
+            // Check third bulk dataset
+            assert_eq!(args.datasets[2].url, "https://bulk-dataset-3.bin");
+            assert_eq!(args.datasets[2].checksum, "0x323checksum");
+            assert_eq!(args.datasets[2].filename, "bulk-dataset-3.txt");
+            assert_eq!(args.datasets[2].key, "bulkKey323");
+        });
+    }
+
+    #[test]
+    fn read_args_succeeds_with_both_dataset_and_bulk_datasets() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.extend(setup_dataset_env_vars());
+        env_vars.extend(setup_input_files_env_vars(0));
+        env_vars.extend(setup_bulk_dataset_env_vars(2));
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+
+            assert!(result.is_ok());
+            let args = result.unwrap();
+
+            assert_eq!(args.output_dir, OUTPUT_DIR);
+            assert!(args.is_dataset_required);
+            assert_eq!(args.bulk_size, 2);
+            assert_eq!(args.datasets.len(), 3); // 1 regular + 2 bulk datasets
+            assert_eq!(args.input_files.len(), 0);
+
+            // Check regular dataset (first in list)
+            assert_eq!(args.datasets[0].url, DATASET_URL);
+            assert_eq!(args.datasets[0].checksum, DATASET_CHECKSUM);
+            assert_eq!(args.datasets[0].filename, DATASET_FILENAME);
+            assert_eq!(args.datasets[0].key, DATASET_KEY);
+
+            // Check bulk datasets
+            assert_eq!(args.datasets[1].url, "https://bulk-dataset-1.bin");
+            assert_eq!(args.datasets[2].url, "https://bulk-dataset-2.bin");
+        });
+    }
+
+    #[test]
+    fn read_args_fails_when_invalid_bulk_size_format() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.insert(IsDatasetRequired.name(), "false".to_string());
+        env_vars.insert(BulkSize.name(), "not-a-number".to_string());
+        env_vars.extend(setup_input_files_env_vars(0));
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+            assert!(result.is_err());
+            assert_eq!(
+                result.unwrap_err(),
+                ReplicateStatusCause::PreComputeIsDatasetRequiredMissing
+            );
+        });
+    }
+
+    #[test]
+    fn read_args_fails_when_bulk_dataset_url_missing() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.insert(IsDatasetRequired.name(), "false".to_string());
+        env_vars.extend(setup_input_files_env_vars(0));
+        env_vars.extend(setup_bulk_dataset_env_vars(2));
+        // Remove one of the bulk dataset URLs
+        env_vars.remove(&BulkDatasetUrl(1).name());
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+            assert!(result.is_err());
+            assert_eq!(
+                result.unwrap_err(),
+                ReplicateStatusCause::PreComputeDatasetUrlMissing
+            );
+        });
+    }
+
+    #[test]
+    fn read_args_fails_when_bulk_dataset_checksum_missing() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.insert(IsDatasetRequired.name(), "false".to_string());
+        env_vars.extend(setup_input_files_env_vars(0));
+        env_vars.extend(setup_bulk_dataset_env_vars(2));
+        // Remove one of the bulk dataset checksums
+        env_vars.remove(&BulkDatasetChecksum(2).name());
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+            assert!(result.is_err());
+            assert_eq!(
+                result.unwrap_err(),
+                ReplicateStatusCause::PreComputeDatasetChecksumMissing
+            );
+        });
+    }
+
+    #[test]
+    fn read_args_fails_when_bulk_dataset_filename_missing() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.insert(IsDatasetRequired.name(), "false".to_string());
+        env_vars.extend(setup_input_files_env_vars(0));
+        env_vars.extend(setup_bulk_dataset_env_vars(3));
+        // Remove one of the bulk dataset filenames
+        env_vars.remove(&BulkDatasetFilename(2).name());
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+            assert!(result.is_err());
+            assert_eq!(
+                result.unwrap_err(),
+                ReplicateStatusCause::PreComputeDatasetFilenameMissing
+            );
+        });
+    }
+
+    #[test]
+    fn read_args_fails_when_bulk_dataset_key_missing() {
+        let mut env_vars = setup_basic_env_vars();
+        env_vars.insert(IsDatasetRequired.name(), "false".to_string());
+        env_vars.extend(setup_input_files_env_vars(0));
+        env_vars.extend(setup_bulk_dataset_env_vars(2));
+        // Remove one of the bulk dataset keys
+        env_vars.remove(&BulkDatasetKey(1).name());
+
+        temp_env::with_vars(to_temp_env_vars(env_vars), || {
+            let result = PreComputeArgs::read_args();
+            assert!(result.is_err());
+            assert_eq!(
+                result.unwrap_err(),
+                ReplicateStatusCause::PreComputeDatasetKeyMissing
+            );
+        });
+    }
+    // endregion
+
     // region dataset environment variables
     #[test]
     fn read_args_fails_when_dataset_env_var_missing() {
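
For reference, this is the environment `read_args()` consumes when a required dataset is combined with `BULK_SIZE=2`, exercised with the same `temp_env` helper the tests use (all values are placeholders):

    // Sketch: one regular dataset plus two bulk datasets (placeholder values).
    temp_env::with_vars(
        [
            ("IEXEC_PRE_COMPUTE_OUT", Some("/iexec_out")),
            ("IEXEC_DATASET_REQUIRED", Some("true")),
            ("IEXEC_DATASET_URL", Some("https://host/dataset.bin")),
            ("IEXEC_DATASET_KEY", Some("base64-dataset-key")),
            ("IEXEC_DATASET_CHECKSUM", Some("0xdataset-checksum")),
            ("IEXEC_DATASET_FILENAME", Some("dataset.txt")),
            ("IEXEC_INPUT_FILES_NUMBER", Some("0")),
            ("BULK_SIZE", Some("2")),
            ("BULK_DATASET_1_URL", Some("https://host/bulk-1.bin")),
            ("BULK_DATASET_1_CHECKSUM", Some("0xbulk-1-checksum")),
            ("BULK_DATASET_1_FILENAME", Some("bulk-1.txt")),
            ("BULK_DATASET_1_KEY", Some("base64-bulk-1-key")),
            ("BULK_DATASET_2_URL", Some("https://host/bulk-2.bin")),
            ("BULK_DATASET_2_CHECKSUM", Some("0xbulk-2-checksum")),
            ("BULK_DATASET_2_FILENAME", Some("bulk-2.txt")),
            ("BULK_DATASET_2_KEY", Some("base64-bulk-2-key")),
        ],
        || {
            let args = PreComputeArgs::read_args().unwrap();
            assert_eq!(args.bulk_size, 2);
            assert_eq!(args.datasets.len(), 3); // regular dataset first, then bulk 1 and 2
        },
    );

The required dataset, when present, always lands at index 0, so downstream consumers can keep treating `datasets[0]` as the task's primary dataset.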
diff --git a/pre-compute/src/compute/utils/env_utils.rs b/pre-compute/src/compute/utils/env_utils.rs
index 12a4120..972b1f4 100644
--- a/pre-compute/src/compute/utils/env_utils.rs
+++ b/pre-compute/src/compute/utils/env_utils.rs
@@ -2,6 +2,11 @@ use crate::compute::errors::ReplicateStatusCause;
 use std::env;
 
 pub enum TeeSessionEnvironmentVariable {
+    BulkSize,
+    BulkDatasetUrl(usize),
+    BulkDatasetChecksum(usize),
+    BulkDatasetFilename(usize),
+    BulkDatasetKey(usize),
     IexecDatasetChecksum,
     IexecDatasetFilename,
     IexecDatasetKey,
@@ -19,6 +24,19 @@ pub enum TeeSessionEnvironmentVariable {
 impl TeeSessionEnvironmentVariable {
     pub fn name(&self) -> String {
         match self {
+            TeeSessionEnvironmentVariable::BulkSize => "BULK_SIZE".to_string(),
+            TeeSessionEnvironmentVariable::BulkDatasetUrl(index) => {
+                format!("BULK_DATASET_{index}_URL")
+            }
+            TeeSessionEnvironmentVariable::BulkDatasetChecksum(index) => {
+                format!("BULK_DATASET_{index}_CHECKSUM")
+            }
+            TeeSessionEnvironmentVariable::BulkDatasetFilename(index) => {
+                format!("BULK_DATASET_{index}_FILENAME")
+            }
+            TeeSessionEnvironmentVariable::BulkDatasetKey(index) => {
+                format!("BULK_DATASET_{index}_KEY")
+            }
             TeeSessionEnvironmentVariable::IexecDatasetChecksum => {
                 "IEXEC_DATASET_CHECKSUM".to_string()
             }
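
A quick sketch of how the indexed variants expand, matching the names `read_args()` looks up:

    // Within the crate, where the enum is in scope (the alias is illustrative).
    type Var = TeeSessionEnvironmentVariable;

    assert_eq!(Var::BulkSize.name(), "BULK_SIZE");
    assert_eq!(Var::BulkDatasetUrl(1).name(), "BULK_DATASET_1_URL");
    assert_eq!(Var::BulkDatasetKey(42).name(), "BULK_DATASET_42_KEY");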