diff --git a/mixpanel/__init__.py b/mixpanel/__init__.py index decc461..abc9aca 100644 --- a/mixpanel/__init__.py +++ b/mixpanel/__init__.py @@ -30,7 +30,7 @@ from .flags.remote_feature_flags import RemoteFeatureFlagsProvider from .flags.types import LocalFlagsConfig, RemoteFlagsConfig -__version__ = '5.0.0b2' +__version__ = '5.0.0' logger = logging.getLogger(__name__) diff --git a/mixpanel/flags/local_feature_flags.py b/mixpanel/flags/local_feature_flags.py index 4b70132..5bc441e 100644 --- a/mixpanel/flags/local_feature_flags.py +++ b/mixpanel/flags/local_feature_flags.py @@ -79,7 +79,7 @@ def start_polling_for_definitions(self): ) self._sync_polling_task.start() else: - logging.warning("A polling task is already running") + logger.warning("A polling task is already running") def stop_polling_for_definitions(self): """ @@ -90,7 +90,7 @@ def stop_polling_for_definitions(self): self._sync_stop_event.set() self._sync_polling_task = None else: - logging.info("There is no polling task to cancel.") + logger.info("There is no polling task to cancel.") async def astart_polling_for_definitions(self): """ @@ -105,7 +105,7 @@ async def astart_polling_for_definitions(self): self._astart_continuous_polling() ) else: - logging.error("A polling task is already running") + logger.error("A polling task is already running") async def astop_polling_for_definitions(self): """ @@ -115,10 +115,10 @@ async def astop_polling_for_definitions(self): self._async_polling_task.cancel() self._async_polling_task = None else: - logging.info("There is no polling task to cancel.") + logger.info("There is no polling task to cancel.") async def _astart_continuous_polling(self): - logging.info( + logger.info( f"Initialized async polling for flag definition updates every '{self._config.polling_interval_in_seconds}' seconds" ) try: @@ -126,10 +126,10 @@ async def _astart_continuous_polling(self): await asyncio.sleep(self._config.polling_interval_in_seconds) await self._afetch_flag_definitions() except asyncio.CancelledError: - logging.info("Async polling was cancelled") + logger.info("Async polling was cancelled") def _start_continuous_polling(self): - logging.info( + logger.info( f"Initialized sync polling for flag definition updates every '{self._config.polling_interval_in_seconds}' seconds" ) while not self._sync_stop_event.is_set(): @@ -146,6 +146,22 @@ def are_flags_ready(self) -> bool: """ return self._are_flags_ready + def get_all_variants(self, context: Dict[str, Any]) -> Dict[str, SelectedVariant]: + """ + Gets the selected variant for all feature flags that the current user context is in the rollout for. + Exposure events are not automatically tracked when this method is used. + :param Dict[str, Any] context: The user context to evaluate against the feature flags + """ + variants: Dict[str, SelectedVariant] = {} + fallback = SelectedVariant(variant_key=None, variant_value=None) + + for flag_key in self._flag_definitions.keys(): + variant = self.get_variant(flag_key, fallback, context, report_exposure=False) + if variant.variant_key is not None: + variants[flag_key] = variant + + return variants + def get_variant_value( self, flag_key: str, fallback_value: Any, context: Dict[str, Any] ) -> Any: @@ -206,16 +222,28 @@ def get_variant( flag_definition, context_value, flag_key, rollout ) - if report_exposure and selected_variant is not None: - end_time = time.perf_counter() - self._track_exposure(flag_key, selected_variant, end_time - start_time, context) + if selected_variant is not None: + if report_exposure: + end_time = time.perf_counter() + self._track_exposure(flag_key, selected_variant, context, end_time - start_time) return selected_variant - logger.info( + logger.debug( f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}" ) return fallback_value + def track_exposure_event(self, flag_key: str, variant: SelectedVariant, context: Dict[str, Any]): + """ + Manually tracks a feature flagging exposure event to Mixpanel. + This is intended to provide flexibility for when individual exposure events are reported when using `get_all_variants` for the user at once with exposure event reporting + + :param str flag_key: The key of the feature flag + :param SelectedVariant variant: The selected variant for the feature flag + :param Dict[str, Any] context: The user context used to evaluate the feature flag + """ + self._track_exposure(flag_key, variant, context) + def _get_variant_override_for_test_user( self, flag_definition: ExperimentationFlag, context: Dict[str, Any] ) -> Optional[SelectedVariant]: @@ -244,10 +272,9 @@ def _get_assigned_variant( ): return variant - - hash_input = str(context_value) + flag_name - - variant_hash = normalized_hash(hash_input, "variant") + stored_salt = flag_definition.hash_salt if flag_definition.hash_salt is not None else "" + salt = flag_name + stored_salt + "variant" + variant_hash = normalized_hash(str(context_value), salt) variants = [variant.model_copy(deep=True) for variant in flag_definition.ruleset.variants] if rollout.variant_splits: @@ -275,13 +302,16 @@ def _get_assigned_rollout( context_value: Any, context: Dict[str, Any], ) -> Optional[Rollout]: - hash_input = str(context_value) + flag_definition.key + for index, rollout in enumerate(flag_definition.ruleset.rollout): + salt = None + if flag_definition.hash_salt is not None: + salt = flag_definition.key + flag_definition.hash_salt + str(index) + else: + salt = flag_definition.key + "rollout" - rollout_hash = normalized_hash(hash_input, "rollout") + rollout_hash = normalized_hash(str(context_value), salt) - for rollout in flag_definition.ruleset.rollout: - if ( - rollout_hash < rollout.rollout_percentage + if (rollout_hash < rollout.rollout_percentage and self._is_runtime_evaluation_satisfied(rollout, context) ): return rollout @@ -352,7 +382,7 @@ def _handle_response( self, response: httpx.Response, start_time: datetime, end_time: datetime ) -> None: request_duration: timedelta = end_time - start_time - logging.info( + logger.debug( f"Request started at '{start_time.isoformat()}', completed at '{end_time.isoformat()}', duration: '{request_duration.total_seconds():.3f}s'" ) @@ -378,8 +408,8 @@ def _track_exposure( self, flag_key: str, variant: SelectedVariant, - latency_in_seconds: float, context: Dict[str, Any], + latency_in_seconds: Optional[float]=None, ): if distinct_id := context.get("distinct_id"): properties = { @@ -387,15 +417,17 @@ def _track_exposure( "Variant name": variant.variant_key, "$experiment_type": "feature_flag", "Flag evaluation mode": "local", - "Variant fetch latency (ms)": latency_in_seconds * 1000, "$experiment_id": variant.experiment_id, "$is_experiment_active": variant.is_experiment_active, "$is_qa_tester": variant.is_qa_tester, } + if latency_in_seconds is not None: + properties["Variant fetch latency (ms)"] = latency_in_seconds * 1000 + self._tracker(distinct_id, EXPOSURE_EVENT, properties) else: - logging.error( + logger.error( "Cannot track exposure event without a distinct_id in the context" ) @@ -406,11 +438,11 @@ def __enter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - logging.info("Exiting the LocalFeatureFlagsProvider and cleaning up resources") + logger.info("Exiting the LocalFeatureFlagsProvider and cleaning up resources") await self.astop_polling_for_definitions() await self._async_client.aclose() def __exit__(self, exc_type, exc_val, exc_tb): - logging.info("Exiting the LocalFeatureFlagsProvider and cleaning up resources") + logger.info("Exiting the LocalFeatureFlagsProvider and cleaning up resources") self.stop_polling_for_definitions() self._sync_client.close() diff --git a/mixpanel/flags/remote_feature_flags.py b/mixpanel/flags/remote_feature_flags.py index af62c74..8d265ae 100644 --- a/mixpanel/flags/remote_feature_flags.py +++ b/mixpanel/flags/remote_feature_flags.py @@ -4,7 +4,7 @@ import urllib.parse import asyncio from datetime import datetime -from typing import Dict, Any, Callable +from typing import Dict, Any, Callable, Tuple, Optional from asgiref.sync import sync_to_async from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse @@ -38,6 +38,26 @@ def __init__( self._sync_client: httpx.Client = httpx.Client(**httpx_client_parameters) self._request_params_base = prepare_common_query_params(self._token, version) + async def aget_all_variants(self, context: Dict[str, Any]) -> Optional[Dict[str, SelectedVariant]]: + """ + Asynchronously gets all feature flag variants for the current user context from remote server. + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + :return: A dictionary mapping flag keys to their selected variants, or None if the call fails + """ + flags: Optional[Dict[str, SelectedVariant]] = None + try: + params = self._prepare_query_params(context) + start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} + response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) + end_time = datetime.now() + self._instrument_call(start_time, end_time) + flags = self._handle_response(response) + except Exception: + logger.exception(f"Failed to get remote variants") + + return flags + async def aget_variant_value( self, flag_key: str, fallback_value: Any, context: Dict[str, Any] ) -> Any: @@ -54,7 +74,7 @@ async def aget_variant_value( return variant.variant_value async def aget_variant( - self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], reportExposure: bool = True ) -> SelectedVariant: """ Asynchronously gets the selected variant of a feature flag variant for the current user context from remote server. @@ -62,19 +82,19 @@ async def aget_variant( :param str flag_key: The key of the feature flag to evaluate :param SelectedVariant fallback_value: The default variant to return if evaluation fails :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + :param bool reportExposure: Whether to report an exposure event if a variant is successfully retrieved """ try: - params = self._prepare_query_params(flag_key, context) + params = self._prepare_query_params(context, flag_key) start_time = datetime.now() headers = {"traceparent": generate_traceparent()} response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) - selected_variant, is_fallback = self._handle_response( - flag_key, fallback_value, response - ) + flags = self._handle_response(response) + selected_variant, is_fallback = self._lookup_flag_in_response(flag_key, flags, fallback_value) - if not is_fallback and (distinct_id := context.get("distinct_id")): + if not is_fallback and reportExposure and (distinct_id := context.get("distinct_id")): properties = self._build_tracking_properties( flag_key, selected_variant, start_time, end_time ) @@ -86,7 +106,7 @@ async def aget_variant( return selected_variant except Exception: - logging.exception(f"Failed to get remote variant for flag '{flag_key}'") + logger.exception(f"Failed to get remote variant for flag '{flag_key}'") return fallback_value async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: @@ -99,6 +119,51 @@ async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: variant_value = await self.aget_variant_value(flag_key, False, context) return variant_value == True + async def atrack_exposure_event( + self, + flag_key: str, + variant: SelectedVariant, + context: Dict[str, Any]): + """ + Manually tracks a feature flagging exposure event asynchronously to Mixpanel. + This is intended to provide flexibility for when individual exposure events are reported when using `get_all_variants` for the user at once with exposure event reporting + + :param str flag_key: The key of the feature flag + :param SelectedVariant variant: The selected variant for the feature flag + :param Dict[str, Any] context: The user context used to evaluate the feature flag + """ + if (distinct_id := context.get("distinct_id")): + properties = self._build_tracking_properties(flag_key, variant) + + await sync_to_async(self._tracker, thread_sensitive=False)( + distinct_id, EXPOSURE_EVENT, properties + ) + else: + logger.error( + "Cannot track exposure event without a distinct_id in the context" + ) + + + def get_all_variants(self, context: Dict[str, Any]) -> Optional[Dict[str, SelectedVariant]]: + """ + Synchronously gets all feature flag variants for the current user context from remote server. + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + :return: A dictionary mapping flag keys to their selected variants, or None if the call fails + """ + flags: Optional[Dict[str, SelectedVariant]] = None + try: + params = self._prepare_query_params(context) + start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} + response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) + end_time = datetime.now() + self._instrument_call(start_time, end_time) + flags = self._handle_response(response) + except Exception: + logger.exception(f"Failed to get remote variants") + + return flags + def get_variant_value( self, flag_key: str, fallback_value: Any, context: Dict[str, Any] ) -> Any: @@ -115,7 +180,7 @@ def get_variant_value( return variant.variant_value def get_variant( - self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], reportExposure: bool = True ) -> SelectedVariant: """ Synchronously gets the selected variant for a feature flag from remote server. @@ -123,19 +188,20 @@ def get_variant( :param str flag_key: The key of the feature flag to evaluate :param SelectedVariant fallback_value: The default variant to return if evaluation fails :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + :param bool reportExposure: Whether to report an exposure event if a variant is successfully retrieved """ try: - params = self._prepare_query_params(flag_key, context) + params = self._prepare_query_params(context, flag_key) start_time = datetime.now() headers = {"traceparent": generate_traceparent()} response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) - selected_variant, is_fallback = self._handle_response( - flag_key, fallback_value, response - ) - if not is_fallback and (distinct_id := context.get("distinct_id")): + flags = self._handle_response(response) + selected_variant, is_fallback = self._lookup_flag_in_response(flag_key, flags, fallback_value) + + if not is_fallback and reportExposure and (distinct_id := context.get("distinct_id")): properties = self._build_tracking_properties( flag_key, selected_variant, start_time, end_time ) @@ -156,20 +222,43 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: variant_value = self.get_variant_value(flag_key, False, context) return variant_value == True + def track_exposure_event( + self, + flag_key: str, + variant: SelectedVariant, + context: Dict[str, Any]): + """ + Manually tracks a feature flagging exposure event synchronously to Mixpanel. + This is intended to provide flexibility for when individual exposure events are reported when using `get_all_variants` for the user at once with exposure event reporting + + :param str flag_key: The key of the feature flag + :param SelectedVariant variant: The selected variant for the feature flag + :param Dict[str, Any] context: The user context used to evaluate the feature flag + """ + if (distinct_id := context.get("distinct_id")): + properties = self._build_tracking_properties(flag_key, variant) + self._tracker(distinct_id, EXPOSURE_EVENT, properties) + else: + logging.error( + "Cannot track exposure event without a distinct_id in the context" + ) + def _prepare_query_params( - self, flag_key: str, context: Dict[str, Any] + self, context: Dict[str, Any], flag_key: Optional[str] = None ) -> Dict[str, str]: params = self._request_params_base.copy() context_json = json.dumps(context).encode("utf-8") url_encoded_context = urllib.parse.quote(context_json) - params.update({"flag_key": flag_key, "context": url_encoded_context}) + params["context"] = url_encoded_context + if flag_key is not None: + params["flag_key"] = flag_key return params def _instrument_call(self, start_time: datetime, end_time: datetime) -> None: request_duration = end_time - start_time formatted_start_time = start_time.isoformat() formatted_end_time = end_time.isoformat() - logging.info( + logging.debug( f"Request started at '{formatted_start_time}', completed at '{formatted_end_time}', duration: '{request_duration.total_seconds():.3f}s'" ) @@ -177,38 +266,44 @@ def _build_tracking_properties( self, flag_key: str, variant: SelectedVariant, - start_time: datetime, - end_time: datetime, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, ) -> Dict[str, Any]: - request_duration = end_time - start_time - formatted_start_time = start_time.isoformat() - formatted_end_time = end_time.isoformat() - - return { + tracking_properties: Dict[str, Any] = { "Experiment name": flag_key, "Variant name": variant.variant_key, "$experiment_type": "feature_flag", "Flag evaluation mode": "remote", - "Variant fetch start time": formatted_start_time, - "Variant fetch complete time": formatted_end_time, - "Variant fetch latency (ms)": request_duration.total_seconds() * 1000, } - def _handle_response( - self, flag_key: str, fallback_value: SelectedVariant, response: httpx.Response - ) -> tuple[SelectedVariant, bool]: - response.raise_for_status() + if start_time is not None and end_time is not None: + request_duration = end_time - start_time + formatted_start_time = start_time.isoformat() + formatted_end_time = end_time.isoformat() + + tracking_properties.update({ + "Variant fetch start time": formatted_start_time, + "Variant fetch complete time": formatted_end_time, + "Variant fetch latency (ms)": request_duration.total_seconds() * 1000, + }) + return tracking_properties + + def _handle_response(self, response: httpx.Response) -> Dict[str, SelectedVariant]: + response.raise_for_status() flags_response = RemoteFlagsResponse.model_validate(response.json()) + return flags_response.flags - if flag_key in flags_response.flags: - return flags_response.flags[flag_key], False + def _lookup_flag_in_response(self, flag_key: str, flags: Dict[str, SelectedVariant], fallback_value: SelectedVariant) -> Tuple[SelectedVariant, bool]: + if flag_key in flags: + return flags[flag_key], False else: - logging.warning( + logging.debug( f"Flag '{flag_key}' not found in remote response. Returning fallback, '{fallback_value}'" ) return fallback_value, True + def __enter__(self): return self diff --git a/mixpanel/flags/test_local_feature_flags.py b/mixpanel/flags/test_local_feature_flags.py index dba1d20..fed3f57 100644 --- a/mixpanel/flags/test_local_feature_flags.py +++ b/mixpanel/flags/test_local_feature_flags.py @@ -6,7 +6,7 @@ from unittest.mock import Mock, patch from typing import Dict, Optional, List from itertools import chain, repeat -from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride +from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride, SelectedVariant from .local_feature_flags import LocalFeatureFlagsProvider @@ -20,8 +20,8 @@ def create_test_flag( test_users: Optional[Dict[str, str]] = None, experiment_id: Optional[str] = None, is_experiment_active: Optional[bool] = None, - variant_splits: Optional[Dict[str, float]] = None) -> ExperimentationFlag: - + variant_splits: Optional[Dict[str, float]] = None, + hash_salt: Optional[str] = None) -> ExperimentationFlag: if variants is None: variants = [ Variant(key="control", value="control", is_control=True, split=50.0), @@ -54,7 +54,8 @@ def create_test_flag( ruleset=ruleset, context=context, experiment_id=experiment_id, - is_experiment_active=is_experiment_active + is_experiment_active=is_experiment_active, + hash_salt=hash_salt ) @@ -319,6 +320,54 @@ async def test_get_variant_value_does_not_track_exposure_without_distinct_id(sel _ = self._flags.get_variant_value("nonexistent_flag", "fallback", {"company_id": "company123"}) self._mock_tracker.assert_not_called() + @respx.mock + async def test_get_all_variants_returns_all_variants_when_user_in_rollout(self): + flag1 = create_test_flag(flag_key="flag1", rollout_percentage=100.0) + flag2 = create_test_flag(flag_key="flag2", rollout_percentage=100.0) + await self.setup_flags([flag1, flag2]) + + result = self._flags.get_all_variants({"distinct_id": "user123"}) + + assert len(result) == 2 and "flag1" in result and "flag2" in result + + @respx.mock + async def test_get_all_variants_returns_partial_variants_when_user_in_some_rollout(self): + flag1 = create_test_flag(flag_key="flag1", rollout_percentage=100.0) + flag2 = create_test_flag(flag_key="flag2", rollout_percentage=0.0) + await self.setup_flags([flag1, flag2]) + + result = self._flags.get_all_variants({"distinct_id": "user123"}) + + assert len(result) == 1 and "flag1" in result and "flag2" not in result + + @respx.mock + async def test_get_all_variants_returns_empty_dict_when_no_flags_configured(self): + await self.setup_flags([]) + + result = self._flags.get_all_variants({"distinct_id": "user123"}) + + assert result == {} + + @respx.mock + async def test_get_all_variants_does_not_track_exposure_events(self): + flag1 = create_test_flag(flag_key="flag1", rollout_percentage=100.0) + flag2 = create_test_flag(flag_key="flag2", rollout_percentage=100.0) + await self.setup_flags([flag1, flag2]) + + _ = self._flags.get_all_variants({"distinct_id": "user123"}) + + self._mock_tracker.assert_not_called() + + @respx.mock + async def test_track_exposure_event_successfully_tracks(self): + flag = create_test_flag() + await self.setup_flags([flag]) + + variant = SelectedVariant(key="treatment", variant_value="treatment") + self._flags.track_exposure_event("test_flag", variant, {"distinct_id": "user123"}) + + self._mock_tracker.assert_called_once() + @respx.mock async def test_are_flags_ready_returns_true_when_flags_loaded(self): flag = create_test_flag() diff --git a/mixpanel/flags/test_remote_feature_flags.py b/mixpanel/flags/test_remote_feature_flags.py index def080c..c2e312e 100644 --- a/mixpanel/flags/test_remote_feature_flags.py +++ b/mixpanel/flags/test_remote_feature_flags.py @@ -88,6 +88,58 @@ async def test_ais_enabled_returns_false_for_false_variant_value(self): result = await self._flags.ais_enabled("test_flag", {"distinct_id": "user123"}) assert result == False + @respx.mock + async def test_aget_all_variants_returns_all_variants_from_api(self): + variants = { + "flag1": SelectedVariant(variant_key="treatment1", variant_value="value1"), + "flag2": SelectedVariant(variant_key="treatment2", variant_value="value2") + } + respx.get(ENDPOINT).mock(return_value=create_success_response(variants)) + + result = await self._flags.aget_all_variants({"distinct_id": "user123"}) + + assert result == variants + + @respx.mock + async def test_aget_all_variants_returns_none_on_network_error(self): + respx.get(ENDPOINT).mock(side_effect=httpx.RequestError("Network error")) + + result = await self._flags.aget_all_variants({"distinct_id": "user123"}) + + assert result is None + + @respx.mock + async def test_aget_all_variants_does_not_track_exposure_events(self): + variants = { + "flag1": SelectedVariant(variant_key="treatment1", variant_value="value1"), + "flag2": SelectedVariant(variant_key="treatment2", variant_value="value2") + } + respx.get(ENDPOINT).mock(return_value=create_success_response(variants)) + + await self._flags.aget_all_variants({"distinct_id": "user123"}) + + self.mock_tracker.assert_not_called() + + @respx.mock + async def test_aget_all_variants_handles_empty_response(self): + respx.get(ENDPOINT).mock(return_value=create_success_response({})) + + result = await self._flags.aget_all_variants({"distinct_id": "user123"}) + + assert result == {} + + @respx.mock + async def test_atrack_exposure_event_successfully_tracks(self): + variant = SelectedVariant(variant_key="treatment", variant_value="treatment") + + await self._flags.atrack_exposure_event("test_flag", variant, {"distinct_id": "user123"}) + + pending = [task for task in asyncio.all_tasks() if not task.done() and task != asyncio.current_task()] + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + self.mock_tracker.assert_called_once() + class TestRemoteFeatureFlagsProviderSync: def setup_method(self): config = RemoteFlagsConfig() @@ -157,3 +209,51 @@ def test_is_enabled_returns_false_for_false_variant_value(self): result = self._flags.is_enabled("test_flag", {"distinct_id": "user123"}) assert result == False + @respx.mock + def test_get_all_variants_returns_all_variants_from_api(self): + variants = { + "flag1": SelectedVariant(variant_key="treatment1", variant_value="value1"), + "flag2": SelectedVariant(variant_key="treatment2", variant_value="value2") + } + respx.get(ENDPOINT).mock(return_value=create_success_response(variants)) + + result = self._flags.get_all_variants({"distinct_id": "user123"}) + + assert result == variants + + @respx.mock + def test_get_all_variants_returns_none_on_network_error(self): + respx.get(ENDPOINT).mock(side_effect=httpx.RequestError("Network error")) + + result = self._flags.get_all_variants({"distinct_id": "user123"}) + + assert result is None + + @respx.mock + def test_get_all_variants_does_not_track_exposure_events(self): + variants = { + "flag1": SelectedVariant(variant_key="treatment1", variant_value="value1"), + "flag2": SelectedVariant(variant_key="treatment2", variant_value="value2") + } + respx.get(ENDPOINT).mock(return_value=create_success_response(variants)) + + self._flags.get_all_variants({"distinct_id": "user123"}) + + self.mock_tracker.assert_not_called() + + @respx.mock + def test_get_all_variants_handles_empty_response(self): + respx.get(ENDPOINT).mock(return_value=create_success_response({})) + + result = self._flags.get_all_variants({"distinct_id": "user123"}) + + assert result == {} + + @respx.mock + def test_track_exposure_event_successfully_tracks(self): + variant = SelectedVariant(variant_key="treatment", variant_value="treatment") + + self._flags.track_exposure_event("test_flag", variant, {"distinct_id": "user123"}) + + self.mock_tracker.assert_called_once() + diff --git a/mixpanel/flags/types.py b/mixpanel/flags/types.py index 20fe6ad..3f2d6b7 100644 --- a/mixpanel/flags/types.py +++ b/mixpanel/flags/types.py @@ -49,6 +49,7 @@ class ExperimentationFlag(BaseModel): context: str experiment_id: Optional[str] = None is_experiment_active: Optional[bool] = None + hash_salt: Optional[str] = None class SelectedVariant(BaseModel):