diff --git a/crates/experimentation_client/src/interface.rs b/crates/experimentation_client/src/interface.rs index 3ec05345..0d5161a8 100644 --- a/crates/experimentation_client/src/interface.rs +++ b/crates/experimentation_client/src/interface.rs @@ -182,16 +182,17 @@ pub extern "C" fn get_applicable_variant( }, Err(err) => return error_block(err), }; - // println!("Fetching variantIds"); let local = task::LocalSet::new(); - let variants = local.block_on(&Runtime::new().unwrap(), unsafe { + let variants_result = local.block_on(&Runtime::new().unwrap(), unsafe { (*client).get_applicable_variant(&context, toss as i8) }); - // println!("variantIds: {:?}", variants); - match serde_json::to_string::>(&variants) { - Ok(result) => rstring_to_cstring(result).into_raw(), - Err(err) => error_block(err.to_string()), - } + variants_result + .map(|result| { + serde_json::to_string(&result) + .map(|json| rstring_to_cstring(json).into_raw()) + .unwrap_or_else(|err| error_block(err.to_string())) + }) + .unwrap_or_else(|err| error_block(err.to_string())) } #[no_mangle] @@ -209,16 +210,15 @@ pub extern "C" fn get_satisfied_experiments( let local = task::LocalSet::new(); let experiments = local.block_on(&Runtime::new().unwrap(), unsafe { - (*client).get_satisfied_experiments(&context) + (*client).get_satisfied_experiments(&context, None) }); let experiments = match serde_json::to_value(experiments) { Ok(value) => value, Err(err) => return error_block(err.to_string()), }; - match serde_json::to_string(&experiments) { - Ok(result) => rstring_to_cstring(result).into_raw(), - Err(err) => error_block(err.to_string()), - } + serde_json::to_string(&experiments) + .map(|exp| rstring_to_cstring(exp).into_raw()) + .unwrap_or_else(|err| error_block(err.to_string())) } #[no_mangle] diff --git a/crates/experimentation_client/src/lib.rs b/crates/experimentation_client/src/lib.rs index a7114302..90741d67 100644 --- a/crates/experimentation_client/src/lib.rs +++ b/crates/experimentation_client/src/lib.rs @@ -1,16 +1,21 @@ mod interface; mod types; -use std::{collections::HashMap, sync::Arc}; +mod utils; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use chrono::{DateTime, TimeZone, Utc}; use derive_more::{Deref, DerefMut}; -use serde_json::Value; +use serde_json::{Map, Value}; use tokio::{ sync::RwLock, time::{self, Duration}, }; pub use types::{Config, Experiment, Experiments, Variants}; use types::{ExperimentStore, ListExperimentsResponse, Variant, VariantType}; +use utils::MapError; #[derive(Clone, Debug)] pub struct Client { @@ -52,7 +57,7 @@ impl Client { self.client_config.tenant.to_string(), ) .await - .unwrap(); + .unwrap_or(HashMap::new()); let mut exp_store = self.experiments.write().await; for (exp_id, experiment) in experiments.into_iter() { @@ -69,34 +74,91 @@ impl Client { } } - pub async fn get_applicable_variant(&self, context: &Value, toss: i8) -> Vec { - let experiments: Experiments = self.get_satisfied_experiments(context).await; + pub async fn get_applicable_variant( + &self, + context: &Value, + toss: i8, + ) -> Result, String> { + let experiments: Experiments = + self.get_satisfied_experiments(context, None).await?; let mut variants: Vec = Vec::new(); for exp in experiments { if let Some(v) = - self.decide_variant(exp.traffic_percentage, exp.variants, toss) + self.decide_variant(exp.traffic_percentage, exp.variants, toss)? { variants.push(v.id) } } - variants + Ok(variants) } - pub async fn get_satisfied_experiments(&self, context: &Value) -> Experiments { + pub async fn get_satisfied_experiments( + &self, + context: &Value, + prefix: Option>, + ) -> Result { let running_experiments = self.experiments.read().await; - running_experiments + let filtered_running_experiments = running_experiments .iter() .filter(|(_, exp)| { jsonlogic::apply(&exp.context, context) == Ok(Value::Bool(true)) }) .map(|(_, exp)| exp.clone()) - .collect::() + .collect::(); + + if let Some(prefix) = prefix { + let prefix_list: HashSet<&str> = prefix.iter().map(|s| s.as_str()).collect(); + + let prefix_filtered_running_experiments: Vec = + filtered_running_experiments + .into_iter() + .filter_map(|experiment| { + let variants: Vec = experiment + .variants + .into_iter() + .filter_map(|mut variant| { + let overrides_map: Map = + serde_json::from_value(variant.overrides.clone()) + .ok()?; + let filtered_override: Map = overrides_map + .into_iter() + .filter(|(key, _)| { + prefix_list + .iter() + .any(|prefix_str| key.starts_with(prefix_str)) + }) + .collect(); + if filtered_override.is_empty() { + return None; // Skip this variant + } + + variant.overrides = + serde_json::to_value(filtered_override).ok()?; + Some(variant) + }) + .collect(); + + if !variants.is_empty() { + Some(Experiment { + variants, + ..experiment + }) + } else { + None // Skip this experiment + } + }) + .collect(); + + return Ok(prefix_filtered_running_experiments); + } + + Ok(filtered_running_experiments) } - pub async fn get_running_experiments(&self) -> Experiments { + pub async fn get_running_experiments(&self) -> Result { let running_experiments = self.experiments.read().await; let experiments: Experiments = running_experiments.values().cloned().collect(); - experiments + Ok(experiments) } // decide which variant to return among all applicable experiments @@ -105,24 +167,28 @@ impl Client { traffic: u8, applicable_variants: Variants, toss: i8, - ) -> Option { + ) -> Result, String> { if toss < 0 { for variant in applicable_variants.iter() { if variant.variant_type == VariantType::EXPERIMENTAL { - return Some(variant.clone()); + return Ok(Some(variant.clone())); } } } let variant_count = applicable_variants.len() as u8; let range = (traffic * variant_count) as i32; if (toss as i32) >= range { - return None; + return Ok(None); } let buckets = (1..=variant_count) .map(|i| (traffic * i) as i8) .collect::>(); - let index = buckets.into_iter().position(|x| toss < x); - applicable_variants.get(index.unwrap()).map(Variant::clone) + let index = buckets + .into_iter() + .position(|x| toss < x) + .ok_or_else(|| "Unable to fetch variant's index".to_string()) + .map_err_to_string()?; + Ok(applicable_variants.get(index).map(Variant::clone)) } } @@ -145,13 +211,12 @@ async fn get_experiments( .header("x-tenant", tenant.to_string()) .send() .await - .unwrap() + .map_err_to_string()? .json::() .await - .unwrap_or_default(); + .map_err_to_string()?; let experiments = list_experiments_response.data; - // println!("got these running experiments: {:?}", running_experiments); for experiment in experiments.into_iter() { curr_exp_store.insert(experiment.id.to_string(), experiment); diff --git a/crates/experimentation_client/src/utils/mod.rs b/crates/experimentation_client/src/utils/mod.rs new file mode 100644 index 00000000..e00789da --- /dev/null +++ b/crates/experimentation_client/src/utils/mod.rs @@ -0,0 +1,14 @@ +use std::fmt; + +pub trait MapError { + fn map_err_to_string(self) -> Result; +} + +impl MapError for Result +where + E: fmt::Display, +{ + fn map_err_to_string(self) -> Result { + self.map_err(|e| e.to_string()) + } +}