diff --git a/open_feature/hooks/hook_support.py b/open_feature/hooks/hook_support.py
index 119359d0..11115044 100644
--- a/open_feature/hooks/hook_support.py
+++ b/open_feature/hooks/hook_support.py
@@ -17,7 +17,7 @@ def error_hooks(
     hooks: typing.List[Hook],
     hints: dict,
 ):
-    kwargs = {"ctx": hook_context, "exception": exception, "hints": hints}
+    kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints}
     _execute_hooks(
         flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs
     )
@@ -29,7 +29,7 @@ def after_all_hooks(
     hooks: typing.List[Hook],
     hints: dict,
 ):
-    kwargs = {"ctx": hook_context, "hints": hints}
+    kwargs = {"hook_context": hook_context, "hints": hints}
     _execute_hooks(
         flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs
     )
@@ -42,7 +42,7 @@ def after_hooks(
     hooks: typing.List[Hook],
     hints: dict,
 ):
-    kwargs = {"ctx": hook_context, "details": details, "hints": hints}
+    kwargs = {"hook_context": hook_context, "details": details, "hints": hints}
     _execute_hooks_unchecked(
         flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs
     )
@@ -54,7 +54,7 @@ def before_hooks(
     hooks: typing.List[Hook],
     hints: dict,
 ) -> EvaluationContext:
-    kwargs = {"ctx": hook_context, "hints": hints}
+    kwargs = {"hook_context": hook_context, "hints": hints}
     executed_hooks = _execute_hooks_unchecked(
         flag_type=flag_type, hooks=hooks, hook_method=HookType.BEFORE, **kwargs
     )
diff --git a/open_feature/open_feature_client.py b/open_feature/open_feature_client.py
index a685f4b7..8be455ef 100644
--- a/open_feature/open_feature_client.py
+++ b/open_feature/open_feature_client.py
@@ -3,7 +3,7 @@
 from numbers import Number
 
 from open_feature.evaluation_context.evaluation_context import EvaluationContext
-from open_feature.exception.exceptions import GeneralError
+from open_feature.exception.exceptions import GeneralError, OpenFeatureError
 from open_feature.flag_evaluation.error_code import ErrorCode
 from open_feature.flag_evaluation.flag_evaluation_details import FlagEvaluationDetails
 from open_feature.flag_evaluation.flag_type import FlagType
@@ -190,7 +190,7 @@ def evaluate_flag_details(
             client_metadata=None,
             provider_metadata=None,
         )
-        merged_hooks = []
+        merged_hooks = self.hooks
 
         try:
             # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
@@ -206,7 +206,7 @@ def evaluate_flag_details(
                 api_evaluation_context().merge(self.context).merge(invocation_context)
             )
 
-            flag_evaluation = self.create_provider_evaluation(
+            flag_evaluation = self._create_provider_evaluation(
                 flag_type,
                 flag_key,
                 default_value,
@@ -217,7 +217,7 @@
 
             return flag_evaluation
 
-        except OpenFeatureError as e:  # noqa
+        except OpenFeatureError as e:
             error_hooks(flag_type, hook_context, e, merged_hooks, None)
             return FlagEvaluationDetails(
                 flag_key=flag_key,
@@ -242,7 +242,7 @@
         finally:
             after_all_hooks(flag_type, hook_context, merged_hooks, None)
 
-    def create_provider_evaluation(
+    def _create_provider_evaluation(
         self,
         flag_type: FlagType,
         flag_key: str,
diff --git a/tests/provider/conftest.py b/tests/conftest.py
similarity index 52%
rename from tests/provider/conftest.py
rename to tests/conftest.py
index b2161c68..af8467ff 100644
--- a/tests/provider/conftest.py
+++ b/tests/conftest.py
@@ -4,6 +4,16 @@
 from open_feature.provider.no_op_provider import NoOpProvider
 
 
+@pytest.fixture(autouse=True)
+def clear_provider():
+    """
+    For tests that use set_provider(), we need to clear the provider to avoid issues
+    in other tests.
+    """
+    yield
+    _provider = None  # noqa: F841
+
+
 @pytest.fixture()
 def no_op_provider_client():
     api.set_provider(NoOpProvider())
diff --git a/tests/provider/test_no_op_provider.py b/tests/provider/test_no_op_provider.py
index aed960ae..1e1e4b96 100644
--- a/tests/provider/test_no_op_provider.py
+++ b/tests/provider/test_no_op_provider.py
@@ -1,16 +1,9 @@
 from numbers import Number
 
-from open_feature import open_feature_api as api
 from open_feature.provider.no_op_provider import NoOpProvider
 
 
-def setup():
-    api.set_provider(NoOpProvider())
-    provider = api.get_provider()
-    assert isinstance(provider, NoOpProvider)
-
-
-def test_should_return_no_op_provider_metadata(no_op_provider_client):
+def test_should_return_no_op_provider_metadata():
     # Given
     # When
     metadata = NoOpProvider().get_metadata()
@@ -20,39 +13,37 @@
     assert metadata.is_default_provider
 
 
-def test_should_get_boolean_flag_from_no_op(no_op_provider_client):
+def test_should_get_boolean_flag_from_no_op():
     # Given
     # When
-    flag = no_op_provider_client.get_boolean_details(flag_key="Key", default_value=True)
+    flag = NoOpProvider().get_boolean_details(flag_key="Key", default_value=True)
     # Then
     assert flag is not None
     assert flag.value
     assert isinstance(flag.value, bool)
 
 
-def test_should_get_number_flag_from_no_op(no_op_provider_client):
+def test_should_get_number_flag_from_no_op():
     # Given
     # When
-    flag = no_op_provider_client.get_number_details(flag_key="Key", default_value=100)
+    flag = NoOpProvider().get_number_details(flag_key="Key", default_value=100)
     # Then
     assert flag is not None
     assert flag.value == 100
     assert isinstance(flag.value, Number)
 
 
-def test_should_get_string_flag_from_no_op(no_op_provider_client):
+def test_should_get_string_flag_from_no_op():
     # Given
     # When
-    flag = no_op_provider_client.get_string_details(
-        flag_key="Key", default_value="String"
-    )
+    flag = NoOpProvider().get_string_details(flag_key="Key", default_value="String")
     # Then
     assert flag is not None
     assert flag.value == "String"
     assert isinstance(flag.value, str)
 
 
-def test_should_get_object_flag_from_no_op(no_op_provider_client):
+def test_should_get_object_flag_from_no_op():
     # Given
     return_value = {
         "String": "string",
@@ -60,9 +51,7 @@
         "Boolean": True,
     }
     # When
-    flag = no_op_provider_client.get_string_details(
-        flag_key="Key", default_value=return_value
-    )
+    flag = NoOpProvider().get_object_details(flag_key="Key", default_value=return_value)
     # Then
     assert flag is not None
     assert flag.value == return_value
diff --git a/tests/test_open_feature_api.py b/tests/test_open_feature_api.py
new file mode 100644
index 00000000..c3bee46d
--- /dev/null
+++ b/tests/test_open_feature_api.py
@@ -0,0 +1,56 @@
+import pytest
+
+from open_feature.exception.exceptions import GeneralError
+from open_feature.flag_evaluation.error_code import ErrorCode
+from open_feature.open_feature_api import get_client, get_provider, set_provider
+from open_feature.provider.no_op_provider import NoOpProvider
+
+
+def test_should_raise_exception_with_nop_client():
+    # Given
+    # When
+    with pytest.raises(GeneralError) as ge:
+        get_client()
+    # Then
+    assert ge.value
+    assert (
+        ge.value.error_message
+        == "Provider not set. Call set_provider before using get_client"
+    )
+    assert ge.value.error_code == ErrorCode.GENERAL
+
+
+def test_should_return_open_feature_client_when_configured_correctly():
+    # Given
+    set_provider(NoOpProvider())
+
+    # When
+    client = get_client(name="No-op Provider", version="1.0")
+
+    # Then
+    assert client.name == "No-op Provider"
+    assert client.version == "1.0"
+    assert isinstance(client.provider, NoOpProvider)
+
+
+def test_should_try_set_provider_and_fail_if_none_provided():
+    # Given
+    # When
+    with pytest.raises(GeneralError) as ge:
+        set_provider(provider=None)
+
+    # Then
+    assert ge.value.error_message == "No provider"
+    assert ge.value.error_code == ErrorCode.GENERAL
+
+
+def test_should_return_a_provider_if_setup_correctly():
+    # Given
+    set_provider(NoOpProvider())
+
+    # When
+    provider = get_provider()
+
+    # Then
+    assert provider
+    assert isinstance(provider, NoOpProvider)
diff --git a/tests/test_open_feature_client.py b/tests/test_open_feature_client.py
new file mode 100644
index 00000000..9e7fb5ee
--- /dev/null
+++ b/tests/test_open_feature_client.py
@@ -0,0 +1,123 @@
+from numbers import Number
+from unittest.mock import MagicMock
+
+import pytest
+
+from open_feature.exception.exceptions import OpenFeatureError
+from open_feature.flag_evaluation.error_code import ErrorCode
+from open_feature.flag_evaluation.reason import Reason
+from open_feature.hooks.hook import Hook
+
+
+@pytest.mark.parametrize(
+    "flag_type, default_value, get_method",
+    (
+        (bool, True, "get_boolean_value"),
+        (str, "String", "get_string_value"),
+        (Number, 100, "get_number_value"),
+        (
+            dict,
+            {
+                "String": "string",
+                "Number": 2,
+                "Boolean": True,
+            },
+            "get_object_value",
+        ),
+    ),
+)
+def test_should_get_flag_value_based_on_method_type(
+    flag_type, default_value, get_method, no_op_provider_client
+):
+    # Given
+    # When
+    flag = getattr(no_op_provider_client, get_method)(
+        flag_key="Key", default_value=default_value
+    )
+    # Then
+    assert flag is not None
+    assert flag == default_value
+    assert isinstance(flag, flag_type)
+
+
+@pytest.mark.parametrize(
+    "flag_type, default_value, get_method",
+    (
+        (bool, True, "get_boolean_details"),
+        (str, "String", "get_string_details"),
+        (Number, 100, "get_number_details"),
+        (
+            dict,
+            {
+                "String": "string",
+                "Number": 2,
+                "Boolean": True,
+            },
+            "get_object_details",
+        ),
+    ),
+)
+def test_should_get_flag_detail_based_on_method_type(
+    flag_type, default_value, get_method, no_op_provider_client
+):
+    # Given
+    # When
+    flag = getattr(no_op_provider_client, get_method)(
+        flag_key="Key", default_value=default_value
+    )
+    # Then
+    assert flag is not None
+    assert flag.value == default_value
+    assert isinstance(flag.value, flag_type)
+
+
+def test_should_raise_exception_when_invalid_flag_type_provided(no_op_provider_client):
+    # Given
+    # When
+    flag = no_op_provider_client.evaluate_flag_details(
+        flag_type=None, flag_key="Key", default_value=True
+    )
+    # Then
+    assert flag.value
+    assert flag.error_message == "Unknown flag type"
+    assert flag.error_code == ErrorCode.GENERAL
+    assert flag.reason == Reason.ERROR
+
+
+def test_should_handle_a_generic_exception_thrown_by_a_provider(no_op_provider_client):
+    # Given
+    exception_hook = MagicMock(spec=Hook)
+    exception_hook.after.side_effect = Exception("Generic exception raised")
+    no_op_provider_client.add_hooks([exception_hook])
+    # When
+    flag_details = no_op_provider_client.get_boolean_details(
+        flag_key="Key", default_value=True
+    )
+    # Then
+    assert flag_details is not None
+    assert flag_details.value
+    assert isinstance(flag_details.value, bool)
+    assert flag_details.reason == Reason.ERROR
+    assert flag_details.error_message == "Generic exception raised"
+
+
+def test_should_handle_an_open_feature_exception_thrown_by_a_provider(
+    no_op_provider_client,
+):
+    # Given
+    exception_hook = MagicMock(spec=Hook)
+    exception_hook.after.side_effect = OpenFeatureError(
+        "error_message", ErrorCode.GENERAL
+    )
+    no_op_provider_client.add_hooks([exception_hook])
+
+    # When
+    flag_details = no_op_provider_client.get_boolean_details(
+        flag_key="Key", default_value=True
+    )
+    # Then
+    assert flag_details is not None
+    assert flag_details.value
+    assert isinstance(flag_details.value, bool)
+    assert flag_details.reason == Reason.ERROR
+    assert flag_details.error_message == "error_message"
diff --git a/tests/test_open_feature_evaluation_context.py b/tests/test_open_feature_evaluation_context.py
new file mode 100644
index 00000000..fc353d96
--- /dev/null
+++ b/tests/test_open_feature_evaluation_context.py
@@ -0,0 +1,33 @@
+import pytest
+
+from open_feature.evaluation_context.evaluation_context import EvaluationContext
+from open_feature.exception.exceptions import GeneralError
+from open_feature.flag_evaluation.error_code import ErrorCode
+from open_feature.open_feature_evaluation_context import (
+    api_evaluation_context,
+    set_api_evaluation_context,
+)
+
+
+def test_should_raise_an_exception_if_no_evaluation_context_set():
+    # Given
+    with pytest.raises(GeneralError) as ge:
+        set_api_evaluation_context(evaluation_context=None)
+    # Then
+    assert ge.value
+    assert ge.value.error_message == "No api level evaluation context"
+    assert ge.value.error_code == ErrorCode.GENERAL
+
+
+def test_should_successfully_set_evaluation_context_for_api():
+    # Given
+    evaluation_context = EvaluationContext("targeting_key", {"attr1": "val1"})
+
+    # When
+    set_api_evaluation_context(evaluation_context)
+    global_evaluation_context = api_evaluation_context()
+
+    # Then
+    assert global_evaluation_context
+    assert global_evaluation_context.targeting_key == evaluation_context.targeting_key
+    assert global_evaluation_context.attributes == evaluation_context.attributes