diff --git a/optimizely/project_config.py b/optimizely/project_config.py index 89c9b48b..1e813462 100644 --- a/optimizely/project_config.py +++ b/optimizely/project_config.py @@ -88,6 +88,38 @@ def __init__(self, datafile: str | bytes, logger: Logger, error_handler: Any): region_value = config.get('region') self.region: str = region_value or 'US' + self.holdouts: list[dict[str, Any]] = config.get('holdouts', []) + self.holdout_id_map: dict[str, dict[str, Any]] = {} + self.global_holdouts: dict[str, dict[str, Any]] = {} + self.included_holdouts: dict[str, list[dict[str, Any]]] = {} + self.excluded_holdouts: dict[str, list[dict[str, Any]]] = {} + self.flag_holdouts_map: dict[str, list[dict[str, Any]]] = {} + + for holdout in self.holdouts: + if holdout.get('status') != 'Running': + continue + + holdout_id = holdout['id'] + self.holdout_id_map[holdout_id] = holdout + + included_flags = holdout.get('includedFlags') + if not included_flags: + # This is a global holdout + self.global_holdouts[holdout_id] = holdout + + excluded_flags = holdout.get('excludedFlags') + if excluded_flags: + for flag_id in excluded_flags: + if flag_id not in self.excluded_holdouts: + self.excluded_holdouts[flag_id] = [] + self.excluded_holdouts[flag_id].append(holdout) + else: + # This holdout applies to specific flags + for flag_id in included_flags: + if flag_id not in self.included_holdouts: + self.included_holdouts[flag_id] = [] + self.included_holdouts[flag_id].append(holdout) + # Utility maps for quick lookup self.group_id_map: dict[str, entities.Group] = self._generate_key_map(self.groups, 'id', entities.Group) self.experiment_id_map: dict[str, entities.Experiment] = self._generate_key_map( @@ -752,3 +784,62 @@ def get_flag_variation( return variation return None + + def get_holdouts_for_flag(self, flag_key: str) -> list[Any]: + """ Helper method to get holdouts from an applied feature flag. + + Args: + flag_key: Key of the feature flag. + + Returns: + The holdouts that apply for a specific flag. + """ + feature_flag = self.feature_key_map.get(flag_key) + if not feature_flag: + return [] + + flag_id = feature_flag.id + + # Check cache first + if flag_id in self.flag_holdouts_map: + return self.flag_holdouts_map[flag_id] + + holdouts = [] + + # Add global holdouts that don't exclude this flag + for holdout in self.global_holdouts.values(): + is_excluded = False + excluded_flags = holdout.get('excludedFlags') + if excluded_flags: + for excluded_flag_id in excluded_flags: + if excluded_flag_id == flag_id: + is_excluded = True + break + if not is_excluded: + holdouts.append(holdout) + + # Add holdouts that specifically include this flag + if flag_id in self.included_holdouts: + holdouts.extend(self.included_holdouts[flag_id]) + + # Cache the result + self.flag_holdouts_map[flag_id] = holdouts + + return holdouts + + def get_holdout(self, holdout_id: str) -> Optional[dict[str, Any]]: + """ Helper method to get holdout from holdout ID. + + Args: + holdout_id: ID of the holdout. + + Returns: + The holdout corresponding to the provided holdout ID. + """ + holdout = self.holdout_id_map.get(holdout_id) + + if holdout: + return holdout + + self.logger.error(f'Holdout with ID "{holdout_id}" not found.') + return None diff --git a/requirements/core.txt b/requirements/core.txt index 7cbfe29f..ea81c17b 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -2,3 +2,4 @@ jsonschema>=3.2.0 pyrsistent>=0.16.0 requests>=2.21 idna>=2.10 +rpds-py<0.20.0; python_version < '3.11' diff --git a/requirements/typing.txt b/requirements/typing.txt index ba65f536..4c01897b 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,4 +1,5 @@ mypy types-jsonschema types-requests -types-Flask \ No newline at end of file +types-Flask +rpds-py<0.20.0; python_version < '3.11' \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py index a6e828c2..08a81f6d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1375,3 +1375,162 @@ def test_get_variation_from_key_by_experiment_id_missing(self): variation = project_config.get_variation_from_key_by_experiment_id(experiment_id, variation_key) self.assertIsNone(variation) + + +class HoldoutConfigTest(base.BaseTest): + def setUp(self): + base.BaseTest.setUp(self) + + # Create config with holdouts + config_body_with_holdouts = self.config_dict_with_features.copy() + + # Use correct feature flag IDs from the datafile + boolean_feature_id = '91111' # boolean_single_variable_feature + multi_variate_feature_id = '91114' # test_feature_in_experiment_and_rollout + + config_body_with_holdouts['holdouts'] = [ + { + 'id': 'holdout_1', + 'key': 'global_holdout', + 'status': 'Running', + 'includedFlags': [], + 'excludedFlags': [boolean_feature_id] + }, + { + 'id': 'holdout_2', + 'key': 'specific_holdout', + 'status': 'Running', + 'includedFlags': [multi_variate_feature_id], + 'excludedFlags': [] + }, + { + 'id': 'holdout_3', + 'key': 'inactive_holdout', + 'status': 'Inactive', + 'includedFlags': [boolean_feature_id], + 'excludedFlags': [] + } + ] + + self.config_json_with_holdouts = json.dumps(config_body_with_holdouts) + opt_obj = optimizely.Optimizely(self.config_json_with_holdouts) + self.config_with_holdouts = opt_obj.config_manager.get_config() + + def test_get_holdouts_for_flag__non_existent_flag(self): + """ Test that get_holdouts_for_flag returns empty array for non-existent flag. """ + + holdouts = self.config_with_holdouts.get_holdouts_for_flag('non_existent_flag') + self.assertEqual([], holdouts) + + def test_get_holdouts_for_flag__returns_global_and_specific_holdouts(self): + """ Test that get_holdouts_for_flag returns global holdouts that do not exclude the flag + and specific holdouts that include the flag. """ + + holdouts = self.config_with_holdouts.get_holdouts_for_flag('test_feature_in_experiment_and_rollout') + self.assertEqual(2, len(holdouts)) + + global_holdout = next((h for h in holdouts if h['key'] == 'global_holdout'), None) + self.assertIsNotNone(global_holdout) + self.assertEqual('holdout_1', global_holdout['id']) + + specific_holdout = next((h for h in holdouts if h['key'] == 'specific_holdout'), None) + self.assertIsNotNone(specific_holdout) + self.assertEqual('holdout_2', specific_holdout['id']) + + def test_get_holdouts_for_flag__excludes_global_holdouts_for_excluded_flags(self): + """ Test that get_holdouts_for_flag does not return global holdouts that exclude the flag. """ + + holdouts = self.config_with_holdouts.get_holdouts_for_flag('boolean_single_variable_feature') + self.assertEqual(0, len(holdouts)) + + global_holdout = next((h for h in holdouts if h['key'] == 'global_holdout'), None) + self.assertIsNone(global_holdout) + + def test_get_holdouts_for_flag__caches_results(self): + """ Test that get_holdouts_for_flag caches results for subsequent calls. """ + + holdouts1 = self.config_with_holdouts.get_holdouts_for_flag('test_feature_in_experiment_and_rollout') + holdouts2 = self.config_with_holdouts.get_holdouts_for_flag('test_feature_in_experiment_and_rollout') + + # Should be the same object (cached) + self.assertIs(holdouts1, holdouts2) + self.assertEqual(2, len(holdouts1)) + + def test_get_holdouts_for_flag__returns_only_global_for_non_targeted_flags(self): + """ Test that get_holdouts_for_flag returns only global holdouts for flags not specifically targeted. """ + + holdouts = self.config_with_holdouts.get_holdouts_for_flag('test_feature_in_rollout') + + # Should only include global holdout (not excluded and no specific targeting) + self.assertEqual(1, len(holdouts)) + self.assertEqual('global_holdout', holdouts[0]['key']) + + def test_get_holdout__returns_holdout_for_valid_id(self): + """ Test that get_holdout returns holdout when valid ID is provided. """ + + holdout = self.config_with_holdouts.get_holdout('holdout_1') + self.assertIsNotNone(holdout) + self.assertEqual('holdout_1', holdout['id']) + self.assertEqual('global_holdout', holdout['key']) + self.assertEqual('Running', holdout['status']) + + def test_get_holdout__returns_holdout_regardless_of_status(self): + """ Test that get_holdout returns holdout regardless of status when valid ID is provided. """ + + holdout = self.config_with_holdouts.get_holdout('holdout_2') + self.assertIsNotNone(holdout) + self.assertEqual('holdout_2', holdout['id']) + self.assertEqual('specific_holdout', holdout['key']) + self.assertEqual('Running', holdout['status']) + + def test_get_holdout__returns_none_for_non_existent_id(self): + """ Test that get_holdout returns None for non-existent holdout ID. """ + + holdout = self.config_with_holdouts.get_holdout('non_existent_holdout') + self.assertIsNone(holdout) + + def test_get_holdout__logs_error_when_not_found(self): + """ Test that get_holdout logs error when holdout is not found. """ + + with mock.patch.object(self.config_with_holdouts, 'logger') as mock_logger: + result = self.config_with_holdouts.get_holdout('invalid_holdout_id') + + self.assertIsNone(result) + mock_logger.error.assert_called_once_with('Holdout with ID "invalid_holdout_id" not found.') + + def test_get_holdout__does_not_log_when_found(self): + """ Test that get_holdout does not log when holdout is found. """ + + with mock.patch.object(self.config_with_holdouts, 'logger') as mock_logger: + result = self.config_with_holdouts.get_holdout('holdout_1') + + self.assertIsNotNone(result) + mock_logger.error.assert_not_called() + + def test_holdout_initialization__categorizes_holdouts_properly(self): + """ Test that holdouts are properly categorized during initialization. """ + + self.assertIn('holdout_1', self.config_with_holdouts.holdout_id_map) + self.assertIn('holdout_2', self.config_with_holdouts.holdout_id_map) + self.assertIn('holdout_1', self.config_with_holdouts.global_holdouts) + + # Use correct feature flag IDs + boolean_feature_id = '91111' + multi_variate_feature_id = '91114' + + self.assertIn(multi_variate_feature_id, self.config_with_holdouts.included_holdouts) + self.assertTrue(len(self.config_with_holdouts.included_holdouts[multi_variate_feature_id]) > 0) + self.assertNotIn(boolean_feature_id, self.config_with_holdouts.included_holdouts) + + self.assertIn(boolean_feature_id, self.config_with_holdouts.excluded_holdouts) + self.assertTrue(len(self.config_with_holdouts.excluded_holdouts[boolean_feature_id]) > 0) + + def test_holdout_initialization__only_processes_running_holdouts(self): + """ Test that only running holdouts are processed during initialization. """ + + self.assertNotIn('holdout_3', self.config_with_holdouts.holdout_id_map) + self.assertNotIn('holdout_3', self.config_with_holdouts.global_holdouts) + + boolean_feature_id = '91111' + included_for_boolean = self.config_with_holdouts.included_holdouts.get(boolean_feature_id) + self.assertIsNone(included_for_boolean)