From 906d90232c1ae628d584b4ecee320d71b8a75a96 Mon Sep 17 00:00:00 2001 From: Kwame Efah <37164746+efahk@users.noreply.github.com> Date: Mon, 20 Oct 2025 23:37:48 +0000 Subject: [PATCH 1/6] Updates to flag providers --- mixpanel/flags/local_feature_flags.py | 53 ++++++++++++++++++++++---- mixpanel/flags/remote_feature_flags.py | 8 ++-- mixpanel/flags/types.py | 12 +++++- mixpanel/flags/utils.py | 18 ++++++++- 4 files changed, 77 insertions(+), 14 deletions(-) diff --git a/mixpanel/flags/local_feature_flags.py b/mixpanel/flags/local_feature_flags.py index 6d95334..e478f60 100644 --- a/mixpanel/flags/local_feature_flags.py +++ b/mixpanel/flags/local_feature_flags.py @@ -4,7 +4,7 @@ import time import threading from datetime import datetime, timedelta -from typing import Dict, Any, Callable, Optional +from typing import List, Dict, Any, Callable, Optional from .types import ( ExperimentationFlag, ExperimentationFlags, @@ -17,6 +17,7 @@ normalized_hash, prepare_common_query_params, EXPOSURE_EVENT, + generate_traceparent, ) logger = logging.getLogger(__name__) @@ -170,8 +171,24 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: variant_value = self.get_variant_value(flag_key, False, context) return bool(variant_value) + def get_all_variants(self, context: Dict[str, Any], reportExposureEvents: bool = False) -> List[SelectedVariant]: + """ + Gets the selected variant for all feature flags that the current user context is in the rollout for. + :param Dict[str, Any] context: The user context to evaluate against the feature flags + :param bool reportExposureEvents: Whether to immediately report exposure events to your Mixpanel project for each flag evaluated. Defaults to False. + """ + variants: Dict[str, SelectedVariant] = {} + fallback = SelectedVariant(variant_key=None, variant_value=None) + + for flag_key in self._flag_definitions.keys(): + variant = self.get_variant(flag_key, fallback, context, report_exposure=reportExposureEvents) + if variant.variant_key is not None: + variants[flag_key] = variant + + return variants + def get_variant( - self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True ) -> SelectedVariant: """ Gets the selected variant for a feature flag @@ -179,6 +196,7 @@ def get_variant( :param str flag_key: The key of the feature flag to evaluate :param SelectedVariant fallback_value: The default variant to return if evaluation fails :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation + :param bool report_exposure: Whether to track an exposure event for this flag evaluation. Defaults to True. """ start_time = time.perf_counter() flag_definition = self._flag_definitions.get(flag_key) @@ -205,7 +223,8 @@ def get_variant( flag_definition, context_value, flag_key, rollout ) end_time = time.perf_counter() - self._track_exposure(flag_key, variant, end_time - start_time, context) + if report_exposure: + self._track_exposure(flag_key, variant, end_time - start_time, context) return variant logger.info( @@ -241,12 +260,17 @@ def _get_assigned_variant( ): return variant - variants = flag_definition.ruleset.variants hash_input = str(context_value) + flag_name variant_hash = normalized_hash(hash_input, "variant") + variants = [variant.model_copy(deep=True) for variant in flag_definition.ruleset.variants] + if rollout.variant_splits: + for variant in variants: + if variant.key in rollout.variant_splits: + variant.split = rollout.variant_splits[variant.key] + selected = variants[0] cumulative = 0.0 for variant in variants: @@ -255,7 +279,11 @@ def _get_assigned_variant( if variant_hash < cumulative: break - return SelectedVariant(variant_key=selected.key, variant_value=selected.value) + return SelectedVariant( + variant_key=selected.key, + variant_value=selected.value, + experiment_id=flag_definition.experiment_id, + is_experiment_active=flag_definition.is_experiment_active) def _get_assigned_rollout( self, @@ -304,15 +332,20 @@ def _get_matching_variant( for variant in flag.ruleset.variants: if variant_key.casefold() == variant.key.casefold(): return SelectedVariant( - variant_key=variant.key, variant_value=variant.value + variant_key=variant.key, + variant_value=variant.value, + experiment_id=flag.experiment_id, + is_experiment_active=flag.is_experiment_active, + is_qa_tester=True, ) return None async def _afetch_flag_definitions(self) -> None: try: start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} response = await self._async_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers ) end_time = datetime.now() self._handle_response(response, start_time, end_time) @@ -322,8 +355,9 @@ async def _afetch_flag_definitions(self) -> None: def _fetch_flag_definitions(self) -> None: try: start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} response = self._sync_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers ) end_time = datetime.now() self._handle_response(response, start_time, end_time) @@ -370,6 +404,9 @@ def _track_exposure( "$experiment_type": "feature_flag", "Flag evaluation mode": "local", "Variant fetch latency (ms)": latency_in_seconds * 1000, + "$experiment_id": variant.experiment_id, + "$is_experiment_active": variant.is_experiment_active, + "$is_qa_tester": variant.is_qa_tester, } self._tracker(distinct_id, EXPOSURE_EVENT, properties) diff --git a/mixpanel/flags/remote_feature_flags.py b/mixpanel/flags/remote_feature_flags.py index 5d7f5d9..6950a69 100644 --- a/mixpanel/flags/remote_feature_flags.py +++ b/mixpanel/flags/remote_feature_flags.py @@ -8,7 +8,7 @@ from asgiref.sync import sync_to_async from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse -from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params +from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -66,7 +66,8 @@ async def aget_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - response = await self._async_client.get(self.FLAGS_URL_PATH, params=params) + headers = {"traceparent": generate_traceparent()} + response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( @@ -126,7 +127,8 @@ def get_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - response = self._sync_client.get(self.FLAGS_URL_PATH, params=params) + headers = {"traceparent": generate_traceparent()} + response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( diff --git a/mixpanel/flags/types.py b/mixpanel/flags/types.py index 186371d..0bcdb83 100644 --- a/mixpanel/flags/types.py +++ b/mixpanel/flags/types.py @@ -32,6 +32,7 @@ class Rollout(BaseModel): rollout_percentage: float runtime_evaluation_definition: Optional[Dict[str, str]] = None variant_override: Optional[VariantOverride] = None + variant_splits: Optional[Dict[str,float]] = None class RuleSet(BaseModel): variants: List[Variant] @@ -41,16 +42,23 @@ class RuleSet(BaseModel): class ExperimentationFlag(BaseModel): id: str name: str - key: str + key: str status: str project_id: int - ruleset: RuleSet + ruleset: RuleSet context: str + experiment_id: Optional[str] = None + is_experiment_active: Optional[bool] = None + class SelectedVariant(BaseModel): # variant_key can be None if being used as a fallback variant_key: Optional[str] = None variant_value: Any + experiment_id: Optional[str] = None + is_experiment_active: Optional[bool] = None + is_qa_tester: Optional[bool] = None + class ExperimentationFlags(BaseModel): flags: List[ExperimentationFlag] diff --git a/mixpanel/flags/utils.py b/mixpanel/flags/utils.py index 987392b..f89ff77 100644 --- a/mixpanel/flags/utils.py +++ b/mixpanel/flags/utils.py @@ -1,4 +1,5 @@ from typing import Dict +from uuid import uuid EXPOSURE_EVENT = "$experiment_started" @@ -47,4 +48,19 @@ def prepare_common_query_params(token: str, sdk_version: str) -> Dict[str, str]: 'token': token } - return params \ No newline at end of file + return params + +def generate_traceparent() -> str: + """ Generates a W3C traceparent header for easy interop with distributed tracing systems i.e Open Telemetry + https://www.w3.org/TR/trace-context/#traceparent-header + :return: A traceparent string + """ + + trace_id = uuid.uuid4().hex + span_id = uuid.uuid4().hex[:16] + + # Trace flags: '01' for sampled + trace_flags = '01' + + traceparent = f"00-{trace_id}-{span_id}-{trace_flags}" + return traceparent \ No newline at end of file From 4b3acb176392ccdb2f9b18e45ba9952a5221e8ea Mon Sep 17 00:00:00 2001 From: Kwame Efah <37164746+efahk@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:23:00 +0000 Subject: [PATCH 2/6] Add additional updates --- mixpanel/flags/local_feature_flags.py | 25 ++++--------------------- mixpanel/flags/remote_feature_flags.py | 9 ++++----- mixpanel/flags/utils.py | 15 +++++++++++---- 3 files changed, 19 insertions(+), 30 deletions(-) diff --git a/mixpanel/flags/local_feature_flags.py b/mixpanel/flags/local_feature_flags.py index e478f60..bf78347 100644 --- a/mixpanel/flags/local_feature_flags.py +++ b/mixpanel/flags/local_feature_flags.py @@ -17,7 +17,7 @@ normalized_hash, prepare_common_query_params, EXPOSURE_EVENT, - generate_traceparent, + add_traceparent_header_to_request ) logger = logging.getLogger(__name__) @@ -50,6 +50,7 @@ def __init__( "headers": REQUEST_HEADERS, "auth": httpx.BasicAuth(token, ""), "timeout": httpx.Timeout(config.request_timeout_in_seconds), + "event_hooks": {"request": [add_traceparent_header_to_request]}, } self._request_params = prepare_common_query_params(self._token, self._version) @@ -171,22 +172,6 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: variant_value = self.get_variant_value(flag_key, False, context) return bool(variant_value) - def get_all_variants(self, context: Dict[str, Any], reportExposureEvents: bool = False) -> List[SelectedVariant]: - """ - Gets the selected variant for all feature flags that the current user context is in the rollout for. - :param Dict[str, Any] context: The user context to evaluate against the feature flags - :param bool reportExposureEvents: Whether to immediately report exposure events to your Mixpanel project for each flag evaluated. Defaults to False. - """ - variants: Dict[str, SelectedVariant] = {} - fallback = SelectedVariant(variant_key=None, variant_value=None) - - for flag_key in self._flag_definitions.keys(): - variant = self.get_variant(flag_key, fallback, context, report_exposure=reportExposureEvents) - if variant.variant_key is not None: - variants[flag_key] = variant - - return variants - def get_variant( self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True ) -> SelectedVariant: @@ -343,9 +328,8 @@ def _get_matching_variant( async def _afetch_flag_definitions(self) -> None: try: start_time = datetime.now() - headers = {"traceparent": generate_traceparent()} response = await self._async_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, ) end_time = datetime.now() self._handle_response(response, start_time, end_time) @@ -355,9 +339,8 @@ async def _afetch_flag_definitions(self) -> None: def _fetch_flag_definitions(self) -> None: try: start_time = datetime.now() - headers = {"traceparent": generate_traceparent()} response = self._sync_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, ) end_time = datetime.now() self._handle_response(response, start_time, end_time) diff --git a/mixpanel/flags/remote_feature_flags.py b/mixpanel/flags/remote_feature_flags.py index 6950a69..9518530 100644 --- a/mixpanel/flags/remote_feature_flags.py +++ b/mixpanel/flags/remote_feature_flags.py @@ -8,7 +8,7 @@ from asgiref.sync import sync_to_async from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse -from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent +from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, add_traceparent_header_to_request logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -30,6 +30,7 @@ def __init__( "headers": REQUEST_HEADERS, "auth": httpx.BasicAuth(token, ""), "timeout": httpx.Timeout(config.request_timeout_in_seconds), + "event_hooks": {"request": [add_traceparent_header_to_request]}, } self._async_client: httpx.AsyncClient = httpx.AsyncClient( @@ -66,8 +67,7 @@ async def aget_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - headers = {"traceparent": generate_traceparent()} - response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) + response = await self._async_client.get(self.FLAGS_URL_PATH, params=params) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( @@ -127,8 +127,7 @@ def get_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - headers = {"traceparent": generate_traceparent()} - response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) + response = self._sync_client.get(self.FLAGS_URL_PATH, params=params) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( diff --git a/mixpanel/flags/utils.py b/mixpanel/flags/utils.py index f89ff77..2106bb9 100644 --- a/mixpanel/flags/utils.py +++ b/mixpanel/flags/utils.py @@ -1,5 +1,6 @@ +import uuid +import httpx from typing import Dict -from uuid import uuid EXPOSURE_EVENT = "$experiment_started" @@ -51,11 +52,10 @@ def prepare_common_query_params(token: str, sdk_version: str) -> Dict[str, str]: return params def generate_traceparent() -> str: - """ Generates a W3C traceparent header for easy interop with distributed tracing systems i.e Open Telemetry + """Generates a W3C traceparent header for easy interop with distributed tracing systems i.e Open Telemetry https://www.w3.org/TR/trace-context/#traceparent-header :return: A traceparent string """ - trace_id = uuid.uuid4().hex span_id = uuid.uuid4().hex[:16] @@ -63,4 +63,11 @@ def generate_traceparent() -> str: trace_flags = '01' traceparent = f"00-{trace_id}-{span_id}-{trace_flags}" - return traceparent \ No newline at end of file + return traceparent + +def add_traceparent_header_to_request(request: httpx.Request) -> None: + """Adds a W3C traceparent header to an outgoing HTTPX request for distributed tracing + :param request: The HTTPX request object + """ + traceparent = generate_traceparent() + request.headers['traceparent'] = traceparent \ No newline at end of file From 119373ed1915a16d6f7cdfa76b7aa55f68c9b819 Mon Sep 17 00:00:00 2001 From: Kwame Efah <37164746+efahk@users.noreply.github.com> Date: Tue, 21 Oct 2025 23:57:48 +0000 Subject: [PATCH 3/6] Add tests --- mixpanel/flags/local_feature_flags.py | 31 ++++----- mixpanel/flags/remote_feature_flags.py | 13 ++-- mixpanel/flags/test_local_feature_flags.py | 76 +++++++++++++++++++++- mixpanel/flags/test_utils.py | 39 +++++++++++ mixpanel/flags/types.py | 2 +- mixpanel/flags/utils.py | 9 +-- 6 files changed, 137 insertions(+), 33 deletions(-) create mode 100644 mixpanel/flags/test_utils.py diff --git a/mixpanel/flags/local_feature_flags.py b/mixpanel/flags/local_feature_flags.py index bf78347..4b70132 100644 --- a/mixpanel/flags/local_feature_flags.py +++ b/mixpanel/flags/local_feature_flags.py @@ -4,7 +4,7 @@ import time import threading from datetime import datetime, timedelta -from typing import List, Dict, Any, Callable, Optional +from typing import Dict, Any, Callable, Optional from .types import ( ExperimentationFlag, ExperimentationFlags, @@ -17,7 +17,7 @@ normalized_hash, prepare_common_query_params, EXPOSURE_EVENT, - add_traceparent_header_to_request + generate_traceparent ) logger = logging.getLogger(__name__) @@ -50,7 +50,6 @@ def __init__( "headers": REQUEST_HEADERS, "auth": httpx.BasicAuth(token, ""), "timeout": httpx.Timeout(config.request_timeout_in_seconds), - "event_hooks": {"request": [add_traceparent_header_to_request]}, } self._request_params = prepare_common_query_params(self._token, self._version) @@ -170,7 +169,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation """ variant_value = self.get_variant_value(flag_key, False, context) - return bool(variant_value) + return variant_value == True def get_variant( self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True @@ -196,21 +195,21 @@ def get_variant( ) return fallback_value + selected_variant: Optional[SelectedVariant] = None + if test_user_variant := self._get_variant_override_for_test_user( flag_definition, context ): - return test_user_variant - - if rollout := self._get_assigned_rollout( - flag_definition, context_value, context - ): - variant = self._get_assigned_variant( + selected_variant = test_user_variant + elif rollout := self._get_assigned_rollout(flag_definition, context_value, context): + selected_variant = self._get_assigned_variant( flag_definition, context_value, flag_key, rollout ) + + if report_exposure and selected_variant is not None: end_time = time.perf_counter() - if report_exposure: - self._track_exposure(flag_key, variant, end_time - start_time, context) - return variant + self._track_exposure(flag_key, selected_variant, end_time - start_time, context) + return selected_variant logger.info( f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}" @@ -328,8 +327,9 @@ def _get_matching_variant( async def _afetch_flag_definitions(self) -> None: try: start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} response = await self._async_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers ) end_time = datetime.now() self._handle_response(response, start_time, end_time) @@ -339,8 +339,9 @@ async def _afetch_flag_definitions(self) -> None: def _fetch_flag_definitions(self) -> None: try: start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} response = self._sync_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers ) end_time = datetime.now() self._handle_response(response, start_time, end_time) diff --git a/mixpanel/flags/remote_feature_flags.py b/mixpanel/flags/remote_feature_flags.py index 9518530..af62c74 100644 --- a/mixpanel/flags/remote_feature_flags.py +++ b/mixpanel/flags/remote_feature_flags.py @@ -8,7 +8,7 @@ from asgiref.sync import sync_to_async from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse -from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, add_traceparent_header_to_request +from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -30,7 +30,6 @@ def __init__( "headers": REQUEST_HEADERS, "auth": httpx.BasicAuth(token, ""), "timeout": httpx.Timeout(config.request_timeout_in_seconds), - "event_hooks": {"request": [add_traceparent_header_to_request]}, } self._async_client: httpx.AsyncClient = httpx.AsyncClient( @@ -67,7 +66,8 @@ async def aget_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - response = await self._async_client.get(self.FLAGS_URL_PATH, params=params) + headers = {"traceparent": generate_traceparent()} + response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( @@ -97,7 +97,7 @@ async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context """ variant_value = await self.aget_variant_value(flag_key, False, context) - return bool(variant_value) + return variant_value == True def get_variant_value( self, flag_key: str, fallback_value: Any, context: Dict[str, Any] @@ -127,7 +127,8 @@ def get_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - response = self._sync_client.get(self.FLAGS_URL_PATH, params=params) + headers = {"traceparent": generate_traceparent()} + response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( @@ -153,7 +154,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context """ variant_value = self.get_variant_value(flag_key, False, context) - return bool(variant_value) + return variant_value == True def _prepare_query_params( self, flag_key: str, context: Dict[str, Any] diff --git a/mixpanel/flags/test_local_feature_flags.py b/mixpanel/flags/test_local_feature_flags.py index 9019a01..dba1d20 100644 --- a/mixpanel/flags/test_local_feature_flags.py +++ b/mixpanel/flags/test_local_feature_flags.py @@ -9,6 +9,7 @@ from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride from .local_feature_flags import LocalFeatureFlagsProvider + def create_test_flag( flag_key: str = "test_flag", context: str = "distinct_id", @@ -16,7 +17,10 @@ def create_test_flag( variant_override: Optional[VariantOverride] = None, rollout_percentage: float = 100.0, runtime_evaluation: Optional[Dict] = None, - test_users: Optional[Dict[str, str]] = None) -> ExperimentationFlag: + test_users: Optional[Dict[str, str]] = None, + experiment_id: Optional[str] = None, + is_experiment_active: Optional[bool] = None, + variant_splits: Optional[Dict[str, float]] = None) -> ExperimentationFlag: if variants is None: variants = [ @@ -27,7 +31,8 @@ def create_test_flag( rollouts = [Rollout( rollout_percentage=rollout_percentage, runtime_evaluation_definition=runtime_evaluation, - variant_override=variant_override + variant_override=variant_override, + variant_splits=variant_splits )] test_config = None @@ -47,7 +52,9 @@ def create_test_flag( status="active", project_id=123, ruleset=ruleset, - context=context + context=context, + experiment_id=experiment_id, + is_experiment_active=is_experiment_active ) @@ -216,6 +223,32 @@ async def test_get_variant_value_picks_correct_variant_with_hundred_percent_spli result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result == "variant_a" + @respx.mock + async def test_get_variant_value_picks_correct_variant_with_half_migrated_group_splits(self): + variants = [ + Variant(key="A", value="variant_a", is_control=False, split=100.0), + Variant(key="B", value="variant_b", is_control=False, split=0.0), + Variant(key="C", value="variant_c", is_control=False, split=0.0) + ] + variant_splits = {"A": 0.0, "B": 100.0, "C": 0.0} + flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + assert result == "variant_b" + + @respx.mock + async def test_get_variant_value_picks_correct_variant_with_full_migrated_group_splits(self): + variants = [ + Variant(key="A", value="variant_a", is_control=False), + Variant(key="B", value="variant_b", is_control=False), + Variant(key="C", value="variant_c", is_control=False), + ] + variant_splits = {"A": 0.0, "B": 0.0, "C": 100.0} + flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + assert result == "variant_c" + @respx.mock async def test_get_variant_value_picks_overriden_variant(self): variants = [ @@ -236,6 +269,43 @@ async def test_get_variant_value_tracks_exposure_when_variant_selected(self): _ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) self._mock_tracker.assert_called_once() + @respx.mock + @pytest.mark.parametrize("experiment_id,is_experiment_active,use_qa_user", [ + ("exp-123", True, True), # QA tester with active experiment + ("exp-456", False, True), # QA tester with inactive experiment + ("exp-789", True, False), # Regular user with active experiment + ("exp-000", False, False), # Regular user with inactive experiment + (None, None, True), # QA tester with no experiment + (None, None, False), # Regular user with no experiment + ]) + async def test_get_variant_value_tracks_exposure_with_correct_properties(self, experiment_id, is_experiment_active, use_qa_user): + flag = create_test_flag( + experiment_id=experiment_id, + is_experiment_active=is_experiment_active, + test_users={"qa_user": "treatment"} + ) + + await self.setup_flags([flag]) + + distinct_id = "qa_user" if use_qa_user else "regular_user" + + with patch('mixpanel.flags.utils.normalized_hash') as mock_hash: + mock_hash.return_value = 0.5 + _ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": distinct_id}) + + self._mock_tracker.assert_called_once() + + call_args = self._mock_tracker.call_args + properties = call_args[0][2] + + assert properties["$experiment_id"] == experiment_id + assert properties["$is_experiment_active"] == is_experiment_active + + if use_qa_user: + assert properties["$is_qa_tester"] == True + else: + assert properties.get("$is_qa_tester") is None + @respx.mock async def test_get_variant_value_does_not_track_exposure_on_fallback(self): await self.setup_flags([]) diff --git a/mixpanel/flags/test_utils.py b/mixpanel/flags/test_utils.py new file mode 100644 index 0000000..8011827 --- /dev/null +++ b/mixpanel/flags/test_utils.py @@ -0,0 +1,39 @@ +import re +import pytest +import random +import string +from .utils import generate_traceparent, normalized_hash + +class TestUtils: + def test_traceparent_format_is_correct(self): + traceparent = generate_traceparent() + + # W3C traceparent format: 00-{32 hex chars}-{16 hex chars}-{2 hex chars} + # https://www.w3.org/TR/trace-context/#traceparent-header + pattern = r'^00-[0-9a-f]{32}-[0-9a-f]{16}-01$' + + assert re.match(pattern, traceparent), f"Traceparent '{traceparent}' does not match W3C format" + + def test_traceparent_pseudo_randomness(self): + traceparents = set() + + for _ in range(100): + traceparents.add(generate_traceparent()) + + assert len(traceparents) == 100, f"Expected 100 unique traceparents, got {len(traceparents)}" + + @pytest.mark.parametrize("key,salt,expected_hash", [ + ("abc", "variant", 0.72), + ("def", "variant", 0.21), + ]) + def test_normalized_hash_for_known_inputs(self, key, salt, expected_hash): + result = normalized_hash(key, salt) + assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}" + + def test_normalized_hash_is_between_0_and_1(self): + for _ in range(100): + length = random.randint(5, 20) + random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + random_salt = ''.join(random.choices(string.ascii_letters, k=10)) + result = normalized_hash(random_string, random_salt) + assert 0.0 <= result < 1.0, f"Hash value {result} is not in range [0, 1] for input '{random_string}'" diff --git a/mixpanel/flags/types.py b/mixpanel/flags/types.py index 0bcdb83..f206962 100644 --- a/mixpanel/flags/types.py +++ b/mixpanel/flags/types.py @@ -20,7 +20,7 @@ class Variant(BaseModel): key: str value: Any is_control: bool - split: float + split: Optional[float] = None class FlagTestUsers(BaseModel): users: Dict[str, str] diff --git a/mixpanel/flags/utils.py b/mixpanel/flags/utils.py index 2106bb9..863a705 100644 --- a/mixpanel/flags/utils.py +++ b/mixpanel/flags/utils.py @@ -63,11 +63,4 @@ def generate_traceparent() -> str: trace_flags = '01' traceparent = f"00-{trace_id}-{span_id}-{trace_flags}" - return traceparent - -def add_traceparent_header_to_request(request: httpx.Request) -> None: - """Adds a W3C traceparent header to an outgoing HTTPX request for distributed tracing - :param request: The HTTPX request object - """ - traceparent = generate_traceparent() - request.headers['traceparent'] = traceparent \ No newline at end of file + return traceparent \ No newline at end of file From 6c0b02d9b7a331174ea885037827acdac7834f0d Mon Sep 17 00:00:00 2001 From: Kwame Efah <37164746+efahk@users.noreply.github.com> Date: Thu, 23 Oct 2025 23:36:37 +0000 Subject: [PATCH 4/6] Update test utils --- mixpanel/flags/test_utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mixpanel/flags/test_utils.py b/mixpanel/flags/test_utils.py index 8011827..f8a1ae1 100644 --- a/mixpanel/flags/test_utils.py +++ b/mixpanel/flags/test_utils.py @@ -14,14 +14,6 @@ def test_traceparent_format_is_correct(self): assert re.match(pattern, traceparent), f"Traceparent '{traceparent}' does not match W3C format" - def test_traceparent_pseudo_randomness(self): - traceparents = set() - - for _ in range(100): - traceparents.add(generate_traceparent()) - - assert len(traceparents) == 100, f"Expected 100 unique traceparents, got {len(traceparents)}" - @pytest.mark.parametrize("key,salt,expected_hash", [ ("abc", "variant", 0.72), ("def", "variant", 0.21), From 386a2cb3606120fab50720cc86dd53e24420c38e Mon Sep 17 00:00:00 2001 From: Kwame Efah <37164746+efahk@users.noreply.github.com> Date: Fri, 24 Oct 2025 21:40:15 +0000 Subject: [PATCH 5/6] Update test utils --- mixpanel/flags/test_utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mixpanel/flags/test_utils.py b/mixpanel/flags/test_utils.py index f8a1ae1..b60b514 100644 --- a/mixpanel/flags/test_utils.py +++ b/mixpanel/flags/test_utils.py @@ -20,12 +20,4 @@ def test_traceparent_format_is_correct(self): ]) def test_normalized_hash_for_known_inputs(self, key, salt, expected_hash): result = normalized_hash(key, salt) - assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}" - - def test_normalized_hash_is_between_0_and_1(self): - for _ in range(100): - length = random.randint(5, 20) - random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=length)) - random_salt = ''.join(random.choices(string.ascii_letters, k=10)) - result = normalized_hash(random_string, random_salt) - assert 0.0 <= result < 1.0, f"Hash value {result} is not in range [0, 1] for input '{random_string}'" + assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}" \ No newline at end of file From df4e5b4f675d6e544e636c4888bcebe2ccd84473 Mon Sep 17 00:00:00 2001 From: Kwame Efah <37164746+efahk@users.noreply.github.com> Date: Fri, 24 Oct 2025 21:42:43 +0000 Subject: [PATCH 6/6] default split --- mixpanel/flags/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mixpanel/flags/types.py b/mixpanel/flags/types.py index f206962..20fe6ad 100644 --- a/mixpanel/flags/types.py +++ b/mixpanel/flags/types.py @@ -20,7 +20,7 @@ class Variant(BaseModel): key: str value: Any is_control: bool - split: Optional[float] = None + split: Optional[float] = 0.0 class FlagTestUsers(BaseModel): users: Dict[str, str]