From 9cafb33d72836a8da0d1bbed37a866b098dc08b3 Mon Sep 17 00:00:00 2001 From: leohoare Date: Tue, 5 Nov 2024 10:33:59 +1100 Subject: [PATCH 1/6] add async methods Signed-off-by: leohoare --- openfeature/client.py | 347 +++++++++++++++++++++++++++++++ openfeature/provider/__init__.py | 75 +++++++ 2 files changed, 422 insertions(+) diff --git a/openfeature/client.py b/openfeature/client.py index 9e4518ec..c61425ef 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -112,6 +112,20 @@ def get_boolean_value( evaluation_context, flag_evaluation_options, ).value + + async def get_boolean_value_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> bool: + return await self.get_boolean_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value def get_boolean_details( self, @@ -128,6 +142,21 @@ def get_boolean_details( flag_evaluation_options, ) + async def get_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[bool]: + return await self.evaluate_flag_details_async( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_string_value( self, flag_key: str, @@ -142,6 +171,20 @@ def get_string_value( flag_evaluation_options, ).value + async def get_string_value_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> str: + return await self.get_string_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value + def get_string_details( self, flag_key: str, @@ -156,6 +199,21 @@ def get_string_details( evaluation_context, flag_evaluation_options, ) + + async def get_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[str]: + return await self.evaluate_flag_details_async( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) def get_integer_value( self, @@ -171,6 +229,20 @@ def get_integer_value( flag_evaluation_options, ).value + async def get_integer_value_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> int: + return await self.get_integer_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value + def get_integer_details( self, flag_key: str, @@ -185,6 +257,21 @@ def get_integer_details( evaluation_context, flag_evaluation_options, ) + + async def get_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[int]: + return await self.evaluate_flag_details_async( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) def get_float_value( self, @@ -199,6 +286,20 @@ def get_float_value( evaluation_context, flag_evaluation_options, ).value + + async def get_float_value_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> float: + return await self.get_float_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value def get_float_details( self, @@ -214,6 +315,21 @@ def get_float_details( evaluation_context, flag_evaluation_options, ) + + async def get_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[float]: + return await self.evaluate_flag_details_async( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) def get_object_value( self, @@ -228,6 +344,20 @@ def get_object_value( evaluation_context, flag_evaluation_options, ).value + + async def get_object_value_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> typing.Union[dict, list]: + return await self.get_object_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value def get_object_details( self, @@ -243,6 +373,21 @@ def get_object_details( evaluation_context, flag_evaluation_options, ) + + async def get_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: + return await self.evaluate_flag_details_async( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) def evaluate_flag_details( # noqa: PLR0915 self, @@ -391,6 +536,154 @@ def evaluate_flag_details( # noqa: PLR0915 finally: after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + async def evaluate_flag_details_async( # noqa: PLR0915 + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + + if evaluation_context is None: + evaluation_context = EvaluationContext() + + if flag_evaluation_options is None: + flag_evaluation_options = FlagEvaluationOptions() + + provider = self.provider # call this once to maintain a consistent reference + evaluation_hooks = flag_evaluation_options.hooks + hook_hints = flag_evaluation_options.hook_hints + + hook_context = HookContext( + flag_key=flag_key, + flag_type=flag_type, + default_value=default_value, + evaluation_context=evaluation_context, + client_metadata=self.get_metadata(), + provider_metadata=provider.get_metadata(), + ) + # Hooks need to be handled in different orders at different stages + # in the flag evaluation + # before: API, Client, Invocation, Provider + merged_hooks = ( + api.get_hooks() + + self.hooks + + evaluation_hooks + + provider.get_provider_hooks() + ) + # after, error, finally: Provider, Invocation, Client, API + reversed_merged_hooks = merged_hooks[:] + reversed_merged_hooks.reverse() + + status = self.get_provider_status() + if status == ProviderStatus.NOT_READY: + error_hooks( + flag_type, + hook_context, + ProviderNotReadyError(), + reversed_merged_hooks, + hook_hints, + ) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_NOT_READY, + ) + if status == ProviderStatus.FATAL: + error_hooks( + flag_type, + hook_context, + ProviderFatalError(), + reversed_merged_hooks, + hook_hints, + ) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_FATAL, + ) + + try: + # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md + # Any resulting evaluation context from a before hook will overwrite + # duplicate fields defined globally, on the client, or in the invocation. + # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context + invocation_context = before_hooks( + flag_type, hook_context, merged_hooks, hook_hints + ) + invocation_context = invocation_context.merge(ctx2=evaluation_context) + + # Requirement 3.2.2 merge: API.context->client.context->invocation.context + merged_context = ( + api.get_evaluation_context() + .merge(self.context) + .merge(invocation_context) + ) + + flag_evaluation = await self._create_provider_evaluation( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + + after_hooks( + flag_type, + hook_context, + flag_evaluation, + reversed_merged_hooks, + hook_hints, + ) + + return flag_evaluation + + except OpenFeatureError as err: + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=err.error_code, + error_message=err.error_message, + ) + # Catch any type of exception here since the user can provide any exception + # in the error hooks + except Exception as err: # pragma: no cover + logger.exception( + "Unable to correctly evaluate flag with key: '%s'", flag_key + ) + + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + error_message = getattr(err, "error_message", str(err)) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=error_message, + ) + + finally: + after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + + def _create_provider_evaluation( self, provider: FeatureProvider, @@ -443,6 +736,60 @@ def _create_provider_evaluation( error_message=resolution.error_message, ) + async def _create_provider_evaluation_async( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Asynchronous encapsulated method to create a FlagEvaluationDetail from a specific provider. + + :param flag_type: the type of the flag being returned + :param key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + args = ( + flag_key, + default_value, + evaluation_context, + ) + + get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = { + FlagType.BOOLEAN: provider.resolve_boolean_details_async, + FlagType.INTEGER: provider.resolve_integer_details_async, + FlagType.FLOAT: provider.resolve_float_details_async, + FlagType.OBJECT: provider.resolve_object_details_async, + FlagType.STRING: provider.resolve_string_details_async, + } + + get_details_callable = get_details_callables.get(flag_type) + if not get_details_callable: + raise GeneralError(error_message="Unknown flag type") + + resolution = await get_details_callable(*args) + resolution.raise_for_error() + + # we need to check the get_args to be compatible with union types. + _typecheck_flag_value(resolution.value, flag_type) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=resolution.value, + variant=resolution.variant, + flag_metadata=resolution.flag_metadata or {}, + reason=resolution.reason, + error_code=resolution.error_code, + error_message=resolution.error_message, + ) + + + def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: _event_support.add_client_handler(self, event, handler) diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 8927551e..903d445f 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -47,6 +47,13 @@ def resolve_boolean_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[bool]: ... + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: ... + def resolve_string_details( self, flag_key: str, @@ -54,6 +61,13 @@ def resolve_string_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[str]: ... + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: ... + def resolve_integer_details( self, flag_key: str, @@ -61,6 +75,13 @@ def resolve_integer_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[int]: ... + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: ... + def resolve_float_details( self, flag_key: str, @@ -68,6 +89,13 @@ def resolve_float_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[float]: ... + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: ... + def resolve_object_details( self, flag_key: str, @@ -75,6 +103,13 @@ def resolve_object_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + class AbstractProvider(FeatureProvider): def attach( @@ -111,6 +146,14 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: pass + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") + @abstractmethod def resolve_string_details( self, @@ -120,6 +163,14 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: pass + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") + @abstractmethod def resolve_integer_details( self, @@ -129,6 +180,14 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: pass + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") + @abstractmethod def resolve_float_details( self, @@ -138,6 +197,14 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: pass + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") + @abstractmethod def resolve_object_details( self, @@ -147,6 +214,14 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: pass + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") + def emit_provider_ready(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_READY, details) From eaf53c12840c4c4df071d4cf7304c2453c7474d4 Mon Sep 17 00:00:00 2001 From: leohoare Date: Thu, 7 Nov 2024 17:15:06 +1100 Subject: [PATCH 2/6] add unit tests Signed-off-by: leohoare --- pyproject.toml | 1 + tests/test_client.py | 92 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 40887d9d..f17eaba6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "behave", "coverage[toml]>=6.5", "pytest", + "pytest-asyncio", ] [tool.hatch.envs.default.scripts] diff --git a/tests/test_client.py b/tests/test_client.py index b51c460c..a98f26e2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -384,3 +384,95 @@ def emit_events_task(): f2 = executor.submit(emit_events_task) f1.result() f2.result() + +@pytest.mark.asyncio +async def test_evaluate_boolean_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_boolean_details_async.return_value = FlagResolutionDetails( + value=True, + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client() + # When + flag = await client.evaluate_flag_details_async( + flag_type=bool, flag_key="Key", default_value=True + ) + + # Then + assert flag is not None + assert flag.value == True + +@pytest.mark.asyncio +async def test_evaluate_string_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_string_details_async.return_value = FlagResolutionDetails( + value="String", + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client() + # When + flag = await client.evaluate_flag_details_async( + flag_type=str, flag_key="Key", default_value="String" + ) + + # Then + assert flag is not None + assert flag.value == "String" + +@pytest.mark.asyncio +async def test_evaluate_integer_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_integer_details_async.return_value = FlagResolutionDetails( + value=100, + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client() + # When + flag = await client.evaluate_flag_details_async( + flag_type=int, flag_key="Key", default_value=100 + ) + + # Then + assert flag is not None + assert flag.value == 100 + +@pytest.mark.asyncio +async def test_evaluate_float_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_float_details_async.return_value = FlagResolutionDetails( + value=10.23, + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client() + # When + flag = await client.evaluate_flag_details_async( + flag_type=float, flag_key="Key", default_value=10.23 + ) + + # Then + assert flag is not None + assert flag.value == 10.23 + + +@pytest.mark.asyncio +async def test_allow_not_implemented_async_functions(): + # Given + provider = NoOpProvider() + set_provider(provider) + # When + with pytest.raises(NotImplementedError) as exc_info: + flag = await provider.resolve_boolean_details_async( + flag_key="Key", default_value=True + ) + raise Exception(flag) + + # Then + assert "does not support async operations" in str(exc_info.value) From 5f1b1517443c09f6022636781b763682c1e03b0a Mon Sep 17 00:00:00 2001 From: leohoare Date: Fri, 8 Nov 2024 13:48:59 +1100 Subject: [PATCH 3/6] fork into separate provider Signed-off-by: leohoare --- openfeature/api.py | 4 +- openfeature/client.py | 488 +++++++++++---------- openfeature/hook/__init__.py | 63 +++ openfeature/hook/_hook_support.py | 116 +++++ openfeature/provider/__init__.py | 125 ++++-- openfeature/provider/in_memory_provider.py | 53 +++ tests/test_async_client.py | 183 ++++++++ tests/test_client.py | 92 ---- 8 files changed, 770 insertions(+), 354 deletions(-) create mode 100644 tests/test_async_client.py diff --git a/openfeature/api.py b/openfeature/api.py index c95d10ac..db82d016 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -1,7 +1,7 @@ import typing from openfeature import _event_support -from openfeature.client import OpenFeatureClient +from openfeature.client import OpenFeatureClient, AsyncOpenFeatureClient from openfeature.evaluation_context import EvaluationContext from openfeature.event import ( EventHandler, @@ -38,6 +38,8 @@ def get_client( ) -> OpenFeatureClient: return OpenFeatureClient(domain=domain, version=version) +def get_client_async(domain: typing.Optional[str] = None, version: typing.Optional[str] = None) -> OpenFeatureClient: + return AsyncOpenFeatureClient(domain=domain, version=version) def set_provider( provider: FeatureProvider, domain: typing.Optional[str] = None diff --git a/openfeature/client.py b/openfeature/client.py index c61425ef..d22a3540 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -23,9 +23,13 @@ from openfeature.hook import Hook, HookContext from openfeature.hook._hook_support import ( after_all_hooks, + after_all_hooks_async, after_hooks, + after_hooks_async, before_hooks, + before_hooks_async, error_hooks, + error_hooks_async, ) from openfeature.provider import FeatureProvider, ProviderStatus from openfeature.provider._registry import provider_registry @@ -33,6 +37,7 @@ __all__ = [ "ClientMetadata", "OpenFeatureClient", + "AsyncOpenFeatureClient", ] logger = logging.getLogger("openfeature") @@ -112,20 +117,6 @@ def get_boolean_value( evaluation_context, flag_evaluation_options, ).value - - async def get_boolean_value_async( - self, - flag_key: str, - default_value: bool, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> bool: - return await self.get_boolean_details_async( - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ).value def get_boolean_details( self, @@ -142,21 +133,6 @@ def get_boolean_details( flag_evaluation_options, ) - async def get_boolean_details_async( - self, - flag_key: str, - default_value: bool, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[bool]: - return await self.evaluate_flag_details_async( - FlagType.BOOLEAN, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) - def get_string_value( self, flag_key: str, @@ -171,20 +147,6 @@ def get_string_value( flag_evaluation_options, ).value - async def get_string_value_async( - self, - flag_key: str, - default_value: str, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> str: - return await self.get_string_details_async( - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ).value - def get_string_details( self, flag_key: str, @@ -199,21 +161,6 @@ def get_string_details( evaluation_context, flag_evaluation_options, ) - - async def get_string_details_async( - self, - flag_key: str, - default_value: str, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[str]: - return await self.evaluate_flag_details_async( - FlagType.STRING, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) def get_integer_value( self, @@ -229,20 +176,6 @@ def get_integer_value( flag_evaluation_options, ).value - async def get_integer_value_async( - self, - flag_key: str, - default_value: int, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> int: - return await self.get_integer_details_async( - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ).value - def get_integer_details( self, flag_key: str, @@ -258,21 +191,6 @@ def get_integer_details( flag_evaluation_options, ) - async def get_integer_details_async( - self, - flag_key: str, - default_value: int, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[int]: - return await self.evaluate_flag_details_async( - FlagType.INTEGER, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) - def get_float_value( self, flag_key: str, @@ -287,20 +205,6 @@ def get_float_value( flag_evaluation_options, ).value - async def get_float_value_async( - self, - flag_key: str, - default_value: float, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> float: - return await self.get_float_details_async( - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ).value - def get_float_details( self, flag_key: str, @@ -315,21 +219,6 @@ def get_float_details( evaluation_context, flag_evaluation_options, ) - - async def get_float_details_async( - self, - flag_key: str, - default_value: float, - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[float]: - return await self.evaluate_flag_details_async( - FlagType.FLOAT, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) def get_object_value( self, @@ -344,20 +233,6 @@ def get_object_value( evaluation_context, flag_evaluation_options, ).value - - async def get_object_value_async( - self, - flag_key: str, - default_value: typing.Union[dict, list], - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> typing.Union[dict, list]: - return await self.get_object_details_async( - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ).value def get_object_details( self, @@ -373,21 +248,6 @@ def get_object_details( evaluation_context, flag_evaluation_options, ) - - async def get_object_details_async( - self, - flag_key: str, - default_value: typing.Union[dict, list], - evaluation_context: typing.Optional[EvaluationContext] = None, - flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[typing.Union[dict, list]]: - return await self.evaluate_flag_details_async( - FlagType.OBJECT, - flag_key, - default_value, - evaluation_context, - flag_evaluation_options, - ) def evaluate_flag_details( # noqa: PLR0915 self, @@ -536,7 +396,252 @@ def evaluate_flag_details( # noqa: PLR0915 finally: after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) - async def evaluate_flag_details_async( # noqa: PLR0915 + def _create_provider_evaluation( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Encapsulated method to create a FlagEvaluationDetail from a specific provider. + + :param flag_type: the type of the flag being returned + :param key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + args = ( + flag_key, + default_value, + evaluation_context, + ) + + get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = { + FlagType.BOOLEAN: provider.resolve_boolean_details, + FlagType.INTEGER: provider.resolve_integer_details, + FlagType.FLOAT: provider.resolve_float_details, + FlagType.OBJECT: provider.resolve_object_details, + FlagType.STRING: provider.resolve_string_details, + } + + get_details_callable = get_details_callables.get(flag_type) + if not get_details_callable: + raise GeneralError(error_message="Unknown flag type") + + resolution = get_details_callable(*args) + resolution.raise_for_error() + + # we need to check the get_args to be compatible with union types. + _typecheck_flag_value(resolution.value, flag_type) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=resolution.value, + variant=resolution.variant, + flag_metadata=resolution.flag_metadata or {}, + reason=resolution.reason, + error_code=resolution.error_code, + error_message=resolution.error_message, + ) + + + def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: + _event_support.add_client_handler(self, event, handler) + + def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None: + _event_support.remove_client_handler(self, event, handler) + + +def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None: + type_map: TypeMap = { + FlagType.BOOLEAN: bool, + FlagType.STRING: str, + FlagType.OBJECT: (dict, list), + FlagType.FLOAT: float, + FlagType.INTEGER: int, + } + _type = type_map.get(flag_type) + if not _type: + raise GeneralError(error_message="Unknown flag type") + if not isinstance(value, _type): + raise TypeMismatchError(f"Expected type {_type} but got {type(value)}") + +class AsyncOpenFeatureClient: + def __init__( + self, + domain: typing.Optional[str], + version: typing.Optional[str], + context: typing.Optional[EvaluationContext] = None, + hooks: typing.Optional[typing.List[Hook]] = None, + ) -> None: + self.domain = domain + self.version = version + self.context = context or EvaluationContext() + self.hooks = hooks or [] + + @property + def provider(self) -> FeatureProvider: + return provider_registry.get_provider(self.domain) + + def get_provider_status(self) -> ProviderStatus: + return provider_registry.get_provider_status(self.provider) + + def get_metadata(self) -> ClientMetadata: + return ClientMetadata(domain=self.domain) + + def add_hooks(self, hooks: typing.List[Hook]) -> None: + self.hooks = self.hooks + hooks + + async def get_boolean_value( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> bool: + return await self.get_boolean_value( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[bool]: + return await self.evaluate_flag_details( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_string_value( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> str: + return await self.get_string_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value + + async def get_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[str]: + return await self.evaluate_flag_details( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_integer_value( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> int: + return await self.get_integer_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value + + async def get_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[int]: + return await self.evaluate_flag_details( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_float_value( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> float: + return await self.get_float_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value + + async def get_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[float]: + return await self.evaluate_flag_details( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def get_object_value( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> typing.Union[dict, list]: + return await self.get_object_details( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ).value + + async def get_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: + return await self.evaluate_flag_details( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + + async def evaluate_flag_details( # noqa: PLR0915 self, flag_type: FlagType, flag_key: str, @@ -555,7 +660,6 @@ async def evaluate_flag_details_async( # noqa: PLR0915 :return: a FlagEvaluationDetails object with the fully evaluated flag from a provider """ - if evaluation_context is None: evaluation_context = EvaluationContext() @@ -589,7 +693,7 @@ async def evaluate_flag_details_async( # noqa: PLR0915 status = self.get_provider_status() if status == ProviderStatus.NOT_READY: - error_hooks( + await error_hooks_async( flag_type, hook_context, ProviderNotReadyError(), @@ -603,7 +707,7 @@ async def evaluate_flag_details_async( # noqa: PLR0915 error_code=ErrorCode.PROVIDER_NOT_READY, ) if status == ProviderStatus.FATAL: - error_hooks( + await error_hooks_async( flag_type, hook_context, ProviderFatalError(), @@ -622,11 +726,11 @@ async def evaluate_flag_details_async( # noqa: PLR0915 # Any resulting evaluation context from a before hook will overwrite # duplicate fields defined globally, on the client, or in the invocation. # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context - invocation_context = before_hooks( + invocation_context = await before_hooks_async( flag_type, hook_context, merged_hooks, hook_hints ) - invocation_context = invocation_context.merge(ctx2=evaluation_context) + invocation_context = invocation_context.merge(ctx2=evaluation_context) # Requirement 3.2.2 merge: API.context->client.context->invocation.context merged_context = ( api.get_evaluation_context() @@ -642,7 +746,7 @@ async def evaluate_flag_details_async( # noqa: PLR0915 merged_context, ) - after_hooks( + await after_hooks_async( flag_type, hook_context, flag_evaluation, @@ -653,7 +757,7 @@ async def evaluate_flag_details_async( # noqa: PLR0915 return flag_evaluation except OpenFeatureError as err: - error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + await error_hooks_async(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) return FlagEvaluationDetails( flag_key=flag_key, @@ -669,7 +773,7 @@ async def evaluate_flag_details_async( # noqa: PLR0915 "Unable to correctly evaluate flag with key: '%s'", flag_key ) - error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + await error_hooks_async(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) error_message = getattr(err, "error_message", str(err)) return FlagEvaluationDetails( @@ -681,10 +785,10 @@ async def evaluate_flag_details_async( # noqa: PLR0915 ) finally: - after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + await after_all_hooks_async(flag_type, hook_context, reversed_merged_hooks, hook_hints) - def _create_provider_evaluation( + async def _create_provider_evaluation( self, provider: FeatureProvider, flag_type: FlagType, @@ -693,7 +797,7 @@ def _create_provider_evaluation( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagEvaluationDetails[typing.Any]: """ - Encapsulated method to create a FlagEvaluationDetail from a specific provider. + Asynchronous encapsulated method to create a FlagEvaluationDetail from a specific provider. :param flag_type: the type of the flag being returned :param key: the string key of the selected flag @@ -716,58 +820,6 @@ def _create_provider_evaluation( FlagType.STRING: provider.resolve_string_details, } - get_details_callable = get_details_callables.get(flag_type) - if not get_details_callable: - raise GeneralError(error_message="Unknown flag type") - - resolution = get_details_callable(*args) - resolution.raise_for_error() - - # we need to check the get_args to be compatible with union types. - _typecheck_flag_value(resolution.value, flag_type) - - return FlagEvaluationDetails( - flag_key=flag_key, - value=resolution.value, - variant=resolution.variant, - flag_metadata=resolution.flag_metadata or {}, - reason=resolution.reason, - error_code=resolution.error_code, - error_message=resolution.error_message, - ) - - async def _create_provider_evaluation_async( - self, - provider: FeatureProvider, - flag_type: FlagType, - flag_key: str, - default_value: typing.Any, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagEvaluationDetails[typing.Any]: - """ - Asynchronous encapsulated method to create a FlagEvaluationDetail from a specific provider. - - :param flag_type: the type of the flag being returned - :param key: the string key of the selected flag - :param default_value: backup value returned if no result found by the provider - :param evaluation_context: Information for the purposes of flag evaluation - :return: a FlagEvaluationDetails object with the fully evaluated flag from a - provider - """ - args = ( - flag_key, - default_value, - evaluation_context, - ) - - get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = { - FlagType.BOOLEAN: provider.resolve_boolean_details_async, - FlagType.INTEGER: provider.resolve_integer_details_async, - FlagType.FLOAT: provider.resolve_float_details_async, - FlagType.OBJECT: provider.resolve_object_details_async, - FlagType.STRING: provider.resolve_string_details_async, - } - get_details_callable = get_details_callables.get(flag_type) if not get_details_callable: raise GeneralError(error_message="Unknown flag type") @@ -789,24 +841,8 @@ async def _create_provider_evaluation_async( ) - def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: _event_support.add_client_handler(self, event, handler) def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None: - _event_support.remove_client_handler(self, event, handler) - - -def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None: - type_map: TypeMap = { - FlagType.BOOLEAN: bool, - FlagType.STRING: str, - FlagType.OBJECT: (dict, list), - FlagType.FLOAT: float, - FlagType.INTEGER: int, - } - _type = type_map.get(flag_type) - if not _type: - raise GeneralError(error_message="Unknown flag type") - if not isinstance(value, _type): - raise TypeMismatchError(f"Expected type {_type} but got {type(value)}") + _event_support.remove_client_handler(self, event, handler) \ No newline at end of file diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py index e98301fa..546d8cab 100644 --- a/openfeature/hook/__init__.py +++ b/openfeature/hook/__init__.py @@ -128,3 +128,66 @@ def supports_flag_value_type(self, flag_type: FlagType) -> bool: or not (False) """ return True + +class AsyncHook: + async def before( + self, hook_context: HookContext, hints: HookHints + ) -> typing.Optional[EvaluationContext]: + """ + Runs before flag is resolved. + + :param hook_context: Information about the particular flag evaluation + :param hints: An immutable mapping of data for users to + communicate to the hooks. + :return: An EvaluationContext. It will be merged with the + EvaluationContext instances from other hooks, the client and API. + """ + return None + + async def after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[typing.Any], + hints: HookHints, + ) -> None: + """ + Runs after a flag is resolved. + + :param hook_context: Information about the particular flag evaluation + :param details: Information about how the flag was resolved, + including any resolved values. + :param hints: A mapping of data for users to communicate to the hooks. + """ + pass + + async def error( + self, hook_context: HookContext, exception: Exception, hints: HookHints + ) -> None: + """ + Run when evaluation encounters an error. Errors thrown will be swallowed. + + :param hook_context: Information about the particular flag evaluation + :param exception: The exception that was thrown + :param hints: A mapping of data for users to communicate to the hooks. + """ + pass + + async def finally_after(self, hook_context: HookContext, hints: HookHints) -> None: + """ + Run after flag evaluation, including any error processing. + This will always run. Errors will be swallowed. + + :param hook_context: Information about the particular flag evaluation + :param hints: A mapping of data for users to communicate to the hooks. + """ + pass + + def supports_flag_value_type(self, flag_type: FlagType) -> bool: + """ + Check to see if the hook supports the particular flag type. + + :param flag_type: particular type of the flag + :return: a boolean containing whether the flag type is supported (True) + or not (False) + """ + return True diff --git a/openfeature/hook/_hook_support.py b/openfeature/hook/_hook_support.py index 349b25f3..c36ebedc 100644 --- a/openfeature/hook/_hook_support.py +++ b/openfeature/hook/_hook_support.py @@ -21,6 +21,17 @@ def error_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs ) +async def error_hooks_async( + flag_type: FlagType, + hook_context: HookContext, + exception: Exception, + hooks: typing.List[Hook], + hints: typing.Optional[HookHints] = None, +) -> None: + kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints} + await _execute_hooks_async( + flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs + ) def after_all_hooks( flag_type: FlagType, @@ -33,6 +44,17 @@ def after_all_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs ) +async def after_all_hooks_async( + flag_type: FlagType, + hook_context: HookContext, + hooks: typing.List[Hook], + hints: typing.Optional[HookHints] = None, +) -> None: + kwargs = {"hook_context": hook_context, "hints": hints} + await _execute_hooks_async( + flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs + ) + def after_hooks( flag_type: FlagType, @@ -46,6 +68,18 @@ def after_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs ) +async def after_hooks_async( + flag_type: FlagType, + hook_context: HookContext, + details: FlagEvaluationDetails[typing.Any], + hooks: typing.List[Hook], + hints: typing.Optional[HookHints] = None, +) -> None: + kwargs = {"hook_context": hook_context, "details": details, "hints": hints} + await _execute_hooks_async_unchecked( + flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs + ) + def before_hooks( flag_type: FlagType, @@ -64,6 +98,22 @@ def before_hooks( return EvaluationContext() +async def before_hooks_async( + flag_type: FlagType, + hook_context: HookContext, + hooks: typing.List[Hook], + hints: typing.Optional[HookHints] = None, +) -> EvaluationContext: + kwargs = {"hook_context": hook_context, "hints": hints} + executed_hooks = await _execute_hooks_async( + flag_type=flag_type, hooks=hooks, hook_method=HookType.BEFORE, **kwargs + ) + filtered_hooks = [result for result in executed_hooks if result is not None] + if filtered_hooks: + return reduce(lambda a, b: a.merge(b), filtered_hooks) + + return EvaluationContext() + def _execute_hooks( flag_type: FlagType, @@ -87,6 +137,28 @@ def _execute_hooks( if hook.supports_flag_value_type(flag_type) ] +async def _execute_hooks_async( + flag_type: FlagType, + hooks: typing.List[Hook], + hook_method: HookType, + **kwargs: typing.Any, +) -> list: + """ + Run multiple hooks of any hook type. All of these hooks will be run through an + exception check. + + :param flag_type: particular type of flag + :param hooks: a list of hooks + :param hook_method: the type of hook that is being run + :param kwargs: arguments that need to be provided to the hook method + :return: a list of results from the applied hook methods + """ + return [ + await _execute_hook_checked_async(hook, hook_method, **kwargs) + for hook in hooks + if hook.supports_flag_value_type(flag_type) + ] + def _execute_hooks_unchecked( flag_type: FlagType, @@ -111,6 +183,29 @@ def _execute_hooks_unchecked( if hook.supports_flag_value_type(flag_type) ] +async def _execute_hooks_async_unchecked( + flag_type: FlagType, + hooks: typing.List[Hook], + hook_method: HookType, + **kwargs: typing.Any, +) -> typing.List[typing.Optional[EvaluationContext]]: + """ + Execute a single hook without checking whether an exception is thrown. This is + used in the before and after hooks since any exception will be caught in the + client. + + :param flag_type: particular type of flag + :param hooks: a list of hooks + :param hook_method: the type of hook that is being run + :param kwargs: arguments that need to be provided to the hook method + :return: a list of results from the applied hook methods + """ + return [ + await getattr(hook, hook_method.value)(**kwargs) + for hook in hooks + if hook.supports_flag_value_type(flag_type) + ] + def _execute_hook_checked( hook: Hook, hook_method: HookType, **kwargs: typing.Any @@ -132,3 +227,24 @@ def _execute_hook_checked( except Exception: # pragma: no cover logger.exception(f"Exception when running {hook_method.value} hooks") return None + +async def _execute_hook_checked_async( + hook: Hook, hook_method: HookType, **kwargs: typing.Any +) -> typing.Optional[EvaluationContext]: + """ + Try and run a single hook and catch any exception thrown. This is used in the + after all and error hooks since any error thrown at this point needs to be caught. + + :param hook: a list of hooks + :param hook_method: the type of hook that is being run + :param kwargs: arguments that need to be provided to the hook method + :return: the result of the hook method + """ + try: + return typing.cast( + "typing.Optional[EvaluationContext]", + await getattr(hook, hook_method.value)(**kwargs), + ) + except Exception: # pragma: no cover + logger.exception(f"Exception when running {hook_method.value} hooks") + return None diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 903d445f..4bb30be7 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -47,13 +47,6 @@ def resolve_boolean_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[bool]: ... - async def resolve_boolean_details_async( - self, - flag_key: str, - default_value: bool, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[bool]: ... - def resolve_string_details( self, flag_key: str, @@ -61,13 +54,6 @@ def resolve_string_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[str]: ... - async def resolve_string_details_async( - self, - flag_key: str, - default_value: str, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[str]: ... - def resolve_integer_details( self, flag_key: str, @@ -75,13 +61,6 @@ def resolve_integer_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[int]: ... - async def resolve_integer_details_async( - self, - flag_key: str, - default_value: int, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[int]: ... - def resolve_float_details( self, flag_key: str, @@ -89,13 +68,6 @@ def resolve_float_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[float]: ... - async def resolve_float_details_async( - self, - flag_key: str, - default_value: float, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[float]: ... - def resolve_object_details( self, flag_key: str, @@ -103,13 +75,6 @@ def resolve_object_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... - async def resolve_object_details_async( - self, - flag_key: str, - default_value: typing.Union[dict, list], - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... - class AbstractProvider(FeatureProvider): def attach( @@ -239,3 +204,93 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: if hasattr(self, "_on_emit"): self._on_emit(self, event, details) + + +class AsyncFeatureProvider(FeatureProvider): + async def attach( + self, + on_emit: typing.Callable[ + [FeatureProvider, ProviderEvent, ProviderEventDetails], None + ], + ) -> None: + self._on_emit = on_emit + + async def detach(self) -> None: + if hasattr(self, "_on_emit"): + del self._on_emit + + async def initialize(self, evaluation_context: EvaluationContext) -> None: + pass + + async def shutdown(self) -> None: + pass + + @abstractmethod + async def get_metadata(self) -> Metadata: + pass + + async def get_provider_hooks(self) -> typing.List[Hook]: + return [] + + @abstractmethod + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + pass + + @abstractmethod + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + pass + + @abstractmethod + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + pass + + @abstractmethod + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + pass + + @abstractmethod + def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + pass + + async def emit_provider_ready(self, details: ProviderEventDetails) -> None: + self.emit(ProviderEvent.PROVIDER_READY, details) + + async def emit_provider_configuration_changed( + self, details: ProviderEventDetails + ) -> None: + self.emit(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, details) + + async def emit_provider_error(self, details: ProviderEventDetails) -> None: + self.emit(ProviderEvent.PROVIDER_ERROR, details) + + async def emit_provider_stale(self, details: ProviderEventDetails) -> None: + self.emit(ProviderEvent.PROVIDER_STALE, details) + + async def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: + if hasattr(self, "_on_emit"): + self._on_emit(self, event, details) \ No newline at end of file diff --git a/openfeature/provider/in_memory_provider.py b/openfeature/provider/in_memory_provider.py index 322f4ed6..7debd0d3 100644 --- a/openfeature/provider/in_memory_provider.py +++ b/openfeature/provider/in_memory_provider.py @@ -117,3 +117,56 @@ def _resolve( if flag is None: raise FlagNotFoundError(f"Flag '{flag_key}' not found") return flag.resolve(evaluation_context) + + +class AsyncInMemoryProvider(InMemoryProvider): + _flags: FlagStorage + + def __init__(self, flags: FlagStorage) -> None: + self._flags = flags.copy() + + def get_metadata(self) -> Metadata: + return InMemoryMetadata() + + def get_provider_hooks(self) -> typing.List[Hook]: + return [] + + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return self._resolve(flag_key, evaluation_context) + + async def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return self._resolve(flag_key, evaluation_context) \ No newline at end of file diff --git a/tests/test_async_client.py b/tests/test_async_client.py new file mode 100644 index 00000000..70e312d3 --- /dev/null +++ b/tests/test_async_client.py @@ -0,0 +1,183 @@ +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from openfeature.api import add_hooks, clear_hooks, get_client_async, set_provider +from openfeature.client import AsyncOpenFeatureClient +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.flag_evaluation import FlagResolutionDetails, Reason +from openfeature.provider import FeatureProvider +from openfeature.provider.in_memory_provider import InMemoryFlag, AsyncInMemoryProvider +from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.hook import AsyncHook + + +async_hook = MagicMock(spec=AsyncHook) +async_hook.before = AsyncMock(return_value=None) +async_hook.after = AsyncMock(return_value=None) + + + +@pytest.mark.asyncio +async def test_should_pass_flag_metadata_from_resolution_to_evaluation_details_async(): + # Given + clear_hooks() + add_hooks([async_hook]) + provider = AsyncInMemoryProvider( + { + "Key": InMemoryFlag( + "true", + {"true": True, "false": False}, + flag_metadata={"foo": "bar"}, + ) + } + ) + set_provider(provider, "my-async-client") + + client = AsyncOpenFeatureClient("my-async-client", None) + + # When + details = await client.get_boolean_details(flag_key="Key", default_value=False) + # Then + clear_hooks() + assert details is not None + assert details.flag_metadata == {"foo": "bar"} + + +@pytest.mark.asyncio +async def test_should_return_client_metadata_with_domain_async(): + # Given + client = AsyncOpenFeatureClient("my-async-client", None, NoOpProvider()) + # When + metadata = client.get_metadata() + # Then + assert metadata is not None + assert metadata.domain == "my-async-client" + + +def test_add_remove_event_handler_async(): + # Given + provider = NoOpProvider() + set_provider(provider) + + spy = MagicMock() + + client = get_client_async() + client.add_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed + ) + client.remove_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed + ) + + provider_details = ProviderEventDetails(message="message") + + # When + provider.emit_provider_configuration_changed(provider_details) + + # Then + spy.provider_configuration_changed.assert_not_called() + + +def test_client_handlers_thread_safety_async(): + provider = NoOpProvider() + set_provider(provider) + + def add_handlers_task(): + def handler(*args, **kwargs): + time.sleep(0.005) + + for _ in range(10): + time.sleep(0.01) + client = get_client_async(str(uuid.uuid4())) + client.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler) + + def emit_events_task(): + for _ in range(10): + time.sleep(0.01) + provider.emit_provider_configuration_changed(ProviderEventDetails()) + + with ThreadPoolExecutor(max_workers=2) as executor: + f1 = executor.submit(add_handlers_task) + f2 = executor.submit(emit_events_task) + f1.result() + f2.result() + +@pytest.mark.asyncio +async def test_evaluate_boolean_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_boolean_details.return_value = FlagResolutionDetails( + value=True, + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client_async() + # When + flag = await client.evaluate_flag_details( + flag_type=bool, flag_key="Key", default_value=True + ) + + # Then + assert flag is not None + assert flag.value == True + +@pytest.mark.asyncio +async def test_evaluate_string_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_string_details.return_value = FlagResolutionDetails( + value="String", + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client_async() + # When + flag = await client.evaluate_flag_details( + flag_type=str, flag_key="Key", default_value="String" + ) + + # Then + assert flag is not None + assert flag.value == "String" + +@pytest.mark.asyncio +async def test_evaluate_integer_flag_details(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_integer_details.return_value = FlagResolutionDetails( + value=100, + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client_async() + # When + flag = await client.evaluate_flag_details( + flag_type=int, flag_key="Key", default_value=100 + ) + + # Then + assert flag is not None + assert flag.value == 100 + +@pytest.mark.asyncio +async def test_evaluate_float_flag_details_async(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.resolve_float_details.return_value = FlagResolutionDetails( + value=10.23, + reason=Reason.TARGETING_MATCH, + ) + set_provider(provider) + client = get_client_async() + # When + flag = await client.evaluate_flag_details( + flag_type=float, flag_key="Key", default_value=10.23 + ) + + # Then + assert flag is not None + assert flag.value == 10.23 \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py index a98f26e2..b51c460c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -384,95 +384,3 @@ def emit_events_task(): f2 = executor.submit(emit_events_task) f1.result() f2.result() - -@pytest.mark.asyncio -async def test_evaluate_boolean_flag_details_async(): - # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_boolean_details_async.return_value = FlagResolutionDetails( - value=True, - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client() - # When - flag = await client.evaluate_flag_details_async( - flag_type=bool, flag_key="Key", default_value=True - ) - - # Then - assert flag is not None - assert flag.value == True - -@pytest.mark.asyncio -async def test_evaluate_string_flag_details_async(): - # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_string_details_async.return_value = FlagResolutionDetails( - value="String", - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client() - # When - flag = await client.evaluate_flag_details_async( - flag_type=str, flag_key="Key", default_value="String" - ) - - # Then - assert flag is not None - assert flag.value == "String" - -@pytest.mark.asyncio -async def test_evaluate_integer_flag_details_async(): - # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_integer_details_async.return_value = FlagResolutionDetails( - value=100, - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client() - # When - flag = await client.evaluate_flag_details_async( - flag_type=int, flag_key="Key", default_value=100 - ) - - # Then - assert flag is not None - assert flag.value == 100 - -@pytest.mark.asyncio -async def test_evaluate_float_flag_details_async(): - # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_float_details_async.return_value = FlagResolutionDetails( - value=10.23, - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client() - # When - flag = await client.evaluate_flag_details_async( - flag_type=float, flag_key="Key", default_value=10.23 - ) - - # Then - assert flag is not None - assert flag.value == 10.23 - - -@pytest.mark.asyncio -async def test_allow_not_implemented_async_functions(): - # Given - provider = NoOpProvider() - set_provider(provider) - # When - with pytest.raises(NotImplementedError) as exc_info: - flag = await provider.resolve_boolean_details_async( - flag_key="Key", default_value=True - ) - raise Exception(flag) - - # Then - assert "does not support async operations" in str(exc_info.value) From 7c48c38730997f4cd9b5528778b1f1c162571578 Mon Sep 17 00:00:00 2001 From: leohoare Date: Fri, 8 Nov 2024 16:15:02 +1100 Subject: [PATCH 4/6] remove unused functions Signed-off-by: leohoare --- openfeature/provider/__init__.py | 42 +------------------------------- 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 4bb30be7..02d1f8b9 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -111,14 +111,6 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: pass - async def resolve_boolean_details_async( - self, - flag_key: str, - default_value: bool, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[bool]: - raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") - @abstractmethod def resolve_string_details( self, @@ -128,14 +120,6 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: pass - async def resolve_string_details_async( - self, - flag_key: str, - default_value: str, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[str]: - raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") - @abstractmethod def resolve_integer_details( self, @@ -145,14 +129,6 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: pass - async def resolve_integer_details_async( - self, - flag_key: str, - default_value: int, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[int]: - raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") - @abstractmethod def resolve_float_details( self, @@ -162,14 +138,6 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: pass - async def resolve_float_details_async( - self, - flag_key: str, - default_value: float, - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[float]: - raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") - @abstractmethod def resolve_object_details( self, @@ -179,14 +147,6 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: pass - async def resolve_object_details_async( - self, - flag_key: str, - default_value: typing.Union[dict, list], - evaluation_context: typing.Optional[EvaluationContext] = None, - ) -> FlagResolutionDetails[typing.Union[dict, list]]: - raise NotImplementedError(f"{self.__class__.__name__} does not support async operations") - def emit_provider_ready(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_READY, details) @@ -293,4 +253,4 @@ async def emit_provider_stale(self, details: ProviderEventDetails) -> None: async def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: if hasattr(self, "_on_emit"): - self._on_emit(self, event, details) \ No newline at end of file + self._on_emit(self, event, details) From 2615979a55b1e510f949fbd157f70fd10d686964 Mon Sep 17 00:00:00 2001 From: leohoare Date: Fri, 8 Nov 2024 16:18:06 +1100 Subject: [PATCH 5/6] cleanup Signed-off-by: leohoare --- openfeature/api.py | 8 ++++++-- openfeature/client.py | 26 +++++++++++++++----------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/openfeature/api.py b/openfeature/api.py index db82d016..f0acc3c7 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -1,7 +1,7 @@ import typing from openfeature import _event_support -from openfeature.client import OpenFeatureClient, AsyncOpenFeatureClient +from openfeature.client import AsyncOpenFeatureClient, OpenFeatureClient from openfeature.evaluation_context import EvaluationContext from openfeature.event import ( EventHandler, @@ -38,9 +38,13 @@ def get_client( ) -> OpenFeatureClient: return OpenFeatureClient(domain=domain, version=version) -def get_client_async(domain: typing.Optional[str] = None, version: typing.Optional[str] = None) -> OpenFeatureClient: + +def get_client_async( + domain: typing.Optional[str] = None, version: typing.Optional[str] = None +) -> AsyncOpenFeatureClient: return AsyncOpenFeatureClient(domain=domain, version=version) + def set_provider( provider: FeatureProvider, domain: typing.Optional[str] = None ) -> None: diff --git a/openfeature/client.py b/openfeature/client.py index d22a3540..dd7c4742 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -190,7 +190,7 @@ def get_integer_details( evaluation_context, flag_evaluation_options, ) - + def get_float_value( self, flag_key: str, @@ -204,7 +204,7 @@ def get_float_value( evaluation_context, flag_evaluation_options, ).value - + def get_float_details( self, flag_key: str, @@ -448,7 +448,6 @@ def _create_provider_evaluation( error_message=resolution.error_message, ) - def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: _event_support.add_client_handler(self, event, handler) @@ -470,6 +469,7 @@ def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None: if not isinstance(value, _type): raise TypeMismatchError(f"Expected type {_type} but got {type(value)}") + class AsyncOpenFeatureClient: def __init__( self, @@ -611,7 +611,7 @@ async def get_float_details( evaluation_context, flag_evaluation_options, ) - + async def get_object_value( self, flag_key: str, @@ -625,7 +625,7 @@ async def get_object_value( evaluation_context, flag_evaluation_options, ).value - + async def get_object_details( self, flag_key: str, @@ -757,7 +757,9 @@ async def evaluate_flag_details( # noqa: PLR0915 return flag_evaluation except OpenFeatureError as err: - await error_hooks_async(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + await error_hooks_async( + flag_type, hook_context, err, reversed_merged_hooks, hook_hints + ) return FlagEvaluationDetails( flag_key=flag_key, @@ -773,7 +775,9 @@ async def evaluate_flag_details( # noqa: PLR0915 "Unable to correctly evaluate flag with key: '%s'", flag_key ) - await error_hooks_async(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + await error_hooks_async( + flag_type, hook_context, err, reversed_merged_hooks, hook_hints + ) error_message = getattr(err, "error_message", str(err)) return FlagEvaluationDetails( @@ -785,8 +789,9 @@ async def evaluate_flag_details( # noqa: PLR0915 ) finally: - await after_all_hooks_async(flag_type, hook_context, reversed_merged_hooks, hook_hints) - + await after_all_hooks_async( + flag_type, hook_context, reversed_merged_hooks, hook_hints + ) async def _create_provider_evaluation( self, @@ -840,9 +845,8 @@ async def _create_provider_evaluation( error_message=resolution.error_message, ) - def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: _event_support.add_client_handler(self, event, handler) def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None: - _event_support.remove_client_handler(self, event, handler) \ No newline at end of file + _event_support.remove_client_handler(self, event, handler) From 272d659f04e1c85fcdee8ccd340181996bce61b4 Mon Sep 17 00:00:00 2001 From: leohoare Date: Tue, 12 Nov 2024 22:13:15 +1100 Subject: [PATCH 6/6] fix unit tests, increased coverage, simplified implementation Signed-off-by: leohoare --- openfeature/client.py | 56 +++----- openfeature/hook/_hook_support.py | 8 ++ openfeature/provider/__init__.py | 45 +------ openfeature/provider/no_op_provider.py | 62 +++++++++ tests/conftest.py | 17 ++- tests/hook/conftest.py | 12 ++ tests/hook/test_hook_support.py | 102 ++++++++++++++- tests/test_async_client.py | 172 +++++++++++++------------ 8 files changed, 306 insertions(+), 168 deletions(-) diff --git a/openfeature/client.py b/openfeature/client.py index dd7c4742..4adf2974 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -470,32 +470,7 @@ def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None: raise TypeMismatchError(f"Expected type {_type} but got {type(value)}") -class AsyncOpenFeatureClient: - def __init__( - self, - domain: typing.Optional[str], - version: typing.Optional[str], - context: typing.Optional[EvaluationContext] = None, - hooks: typing.Optional[typing.List[Hook]] = None, - ) -> None: - self.domain = domain - self.version = version - self.context = context or EvaluationContext() - self.hooks = hooks or [] - - @property - def provider(self) -> FeatureProvider: - return provider_registry.get_provider(self.domain) - - def get_provider_status(self) -> ProviderStatus: - return provider_registry.get_provider_status(self.provider) - - def get_metadata(self) -> ClientMetadata: - return ClientMetadata(domain=self.domain) - - def add_hooks(self, hooks: typing.List[Hook]) -> None: - self.hooks = self.hooks + hooks - +class AsyncOpenFeatureClient(OpenFeatureClient): async def get_boolean_value( self, flag_key: str, @@ -503,12 +478,13 @@ async def get_boolean_value( evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, ) -> bool: - return await self.get_boolean_value( + details = await self.get_boolean_details( flag_key, default_value, evaluation_context, flag_evaluation_options, ) + return details.value async def get_boolean_details( self, @@ -532,12 +508,13 @@ async def get_string_value( evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, ) -> str: - return await self.get_string_details( + details = await self.get_string_details( flag_key, default_value, evaluation_context, flag_evaluation_options, - ).value + ) + return details.value async def get_string_details( self, @@ -561,12 +538,13 @@ async def get_integer_value( evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, ) -> int: - return await self.get_integer_details( + details = await self.get_integer_details( flag_key, default_value, evaluation_context, flag_evaluation_options, - ).value + ) + return details.value async def get_integer_details( self, @@ -590,12 +568,13 @@ async def get_float_value( evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, ) -> float: - return await self.get_float_details( + details = await self.get_float_details( flag_key, default_value, evaluation_context, flag_evaluation_options, - ).value + ) + return details.value async def get_float_details( self, @@ -619,12 +598,13 @@ async def get_object_value( evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, ) -> typing.Union[dict, list]: - return await self.get_object_details( + details = await self.get_object_details( flag_key, default_value, evaluation_context, flag_evaluation_options, - ).value + ) + return details.value async def get_object_details( self, @@ -844,9 +824,3 @@ async def _create_provider_evaluation( error_code=resolution.error_code, error_message=resolution.error_message, ) - - def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None: - _event_support.add_client_handler(self, event, handler) - - def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None: - _event_support.remove_client_handler(self, event, handler) diff --git a/openfeature/hook/_hook_support.py b/openfeature/hook/_hook_support.py index c36ebedc..ce817b36 100644 --- a/openfeature/hook/_hook_support.py +++ b/openfeature/hook/_hook_support.py @@ -21,6 +21,7 @@ def error_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs ) + async def error_hooks_async( flag_type: FlagType, hook_context: HookContext, @@ -33,6 +34,7 @@ async def error_hooks_async( flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs ) + def after_all_hooks( flag_type: FlagType, hook_context: HookContext, @@ -44,6 +46,7 @@ def after_all_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs ) + async def after_all_hooks_async( flag_type: FlagType, hook_context: HookContext, @@ -68,6 +71,7 @@ def after_hooks( flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs ) + async def after_hooks_async( flag_type: FlagType, hook_context: HookContext, @@ -98,6 +102,7 @@ def before_hooks( return EvaluationContext() + async def before_hooks_async( flag_type: FlagType, hook_context: HookContext, @@ -137,6 +142,7 @@ def _execute_hooks( if hook.supports_flag_value_type(flag_type) ] + async def _execute_hooks_async( flag_type: FlagType, hooks: typing.List[Hook], @@ -183,6 +189,7 @@ def _execute_hooks_unchecked( if hook.supports_flag_value_type(flag_type) ] + async def _execute_hooks_async_unchecked( flag_type: FlagType, hooks: typing.List[Hook], @@ -228,6 +235,7 @@ def _execute_hook_checked( logger.exception(f"Exception when running {hook_method.value} hooks") return None + async def _execute_hook_checked_async( hook: Hook, hook_method: HookType, **kwargs: typing.Any ) -> typing.Optional[EvaluationContext]: diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 02d1f8b9..f22a45ce 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -166,32 +166,7 @@ def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: self._on_emit(self, event, details) -class AsyncFeatureProvider(FeatureProvider): - async def attach( - self, - on_emit: typing.Callable[ - [FeatureProvider, ProviderEvent, ProviderEventDetails], None - ], - ) -> None: - self._on_emit = on_emit - - async def detach(self) -> None: - if hasattr(self, "_on_emit"): - del self._on_emit - - async def initialize(self, evaluation_context: EvaluationContext) -> None: - pass - - async def shutdown(self) -> None: - pass - - @abstractmethod - async def get_metadata(self) -> Metadata: - pass - - async def get_provider_hooks(self) -> typing.List[Hook]: - return [] - +class AsyncAbstractProvider(AbstractProvider): @abstractmethod async def resolve_boolean_details( self, @@ -236,21 +211,3 @@ def resolve_object_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[typing.Union[dict, list]]: pass - - async def emit_provider_ready(self, details: ProviderEventDetails) -> None: - self.emit(ProviderEvent.PROVIDER_READY, details) - - async def emit_provider_configuration_changed( - self, details: ProviderEventDetails - ) -> None: - self.emit(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, details) - - async def emit_provider_error(self, details: ProviderEventDetails) -> None: - self.emit(ProviderEvent.PROVIDER_ERROR, details) - - async def emit_provider_stale(self, details: ProviderEventDetails) -> None: - self.emit(ProviderEvent.PROVIDER_STALE, details) - - async def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: - if hasattr(self, "_on_emit"): - self._on_emit(self, event, details) diff --git a/openfeature/provider/no_op_provider.py b/openfeature/provider/no_op_provider.py index 070945c9..cba1d4fb 100644 --- a/openfeature/provider/no_op_provider.py +++ b/openfeature/provider/no_op_provider.py @@ -75,3 +75,65 @@ def resolve_object_details( reason=Reason.DEFAULT, variant=PASSED_IN_DEFAULT, ) + + +class AsyncNoOpProvider(NoOpProvider): + async def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) + + async def resolve_object_details( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return FlagResolutionDetails( + value=default_value, + reason=Reason.DEFAULT, + variant=PASSED_IN_DEFAULT, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 1f0a7982..b3dcc16a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from openfeature import api -from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.provider.no_op_provider import AsyncNoOpProvider, NoOpProvider @pytest.fixture(autouse=True) @@ -13,7 +13,22 @@ def clear_providers(): api.clear_providers() +@pytest.fixture(autouse=True) +def clear_hooks_fixture(): + """ + For tests that use add_hooks(), we need to clear the hooks to avoid issues + in other tests. + """ + api.clear_hooks() + + @pytest.fixture() def no_op_provider_client(): api.set_provider(NoOpProvider()) return api.get_client() + + +@pytest.fixture() +def no_op_provider_client_async(): + api.set_provider(AsyncNoOpProvider()) + return api.get_client_async("my-async-client") diff --git a/tests/hook/conftest.py b/tests/hook/conftest.py index 5a3d8092..bf4af84e 100644 --- a/tests/hook/conftest.py +++ b/tests/hook/conftest.py @@ -3,6 +3,7 @@ import pytest from openfeature.evaluation_context import EvaluationContext +from openfeature.hook import AsyncHook @pytest.fixture() @@ -14,3 +15,14 @@ def mock_hook(): mock_hook.error.return_value = None mock_hook.finally_after.return_value = None return mock_hook + + +@pytest.fixture() +def mock_hook_async(): + mock_hook = AsyncHook() + mock_hook.supports_flag_value_type = mock.MagicMock(return_value=True) + mock_hook.before = mock.AsyncMock(return_value=None) + mock_hook.after = mock.AsyncMock(return_value=None) + mock_hook.error = mock.AsyncMock(return_value=None) + mock_hook.finally_after = mock.AsyncMock(return_value=None) + return mock_hook diff --git a/tests/hook/test_hook_support.py b/tests/hook/test_hook_support.py index 64bb8f6f..49997107 100644 --- a/tests/hook/test_hook_support.py +++ b/tests/hook/test_hook_support.py @@ -1,16 +1,20 @@ -from unittest.mock import ANY, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock import pytest from openfeature.client import ClientMetadata from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagEvaluationDetails, FlagType -from openfeature.hook import Hook, HookContext +from openfeature.hook import AsyncHook, Hook, HookContext from openfeature.hook._hook_support import ( after_all_hooks, + after_all_hooks_async, after_hooks, + after_hooks_async, before_hooks, + before_hooks_async, error_hooks, + error_hooks_async, ) from openfeature.immutable_dict.mapping_proxy_type import MappingProxyType from openfeature.provider.metadata import Metadata @@ -86,6 +90,23 @@ def test_error_hooks_run_error_method(mock_hook): ) +@pytest.mark.asyncio +async def test_error_hooks_run_error_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_hints = MappingProxyType({}) + # When + await error_hooks_async( + FlagType.BOOLEAN, hook_context, Exception, [mock_hook_async], hook_hints + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.error.assert_called_once() + mock_hook_async.error.assert_called_with( + hook_context=hook_context, exception=ANY, hints=hook_hints + ) + + def test_before_hooks_run_before_method(mock_hook): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -98,6 +119,23 @@ def test_before_hooks_run_before_method(mock_hook): mock_hook.before.assert_called_with(hook_context=hook_context, hints=hook_hints) +@pytest.mark.asyncio +async def test_before_hooks_run_before_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_hints = MappingProxyType({}) + # When + await before_hooks_async( + FlagType.BOOLEAN, hook_context, [mock_hook_async], hook_hints + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.before.assert_called_once() + mock_hook_async.before.assert_called_with( + hook_context=hook_context, hints=hook_hints + ) + + def test_before_hooks_merges_evaluation_contexts(): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -115,6 +153,25 @@ def test_before_hooks_merges_evaluation_contexts(): assert context == EvaluationContext("bar", {"key_1": "val_1", "key_2": "val_2"}) +@pytest.mark.asyncio +async def test_before_hooks_async_merges_evaluation_contexts(): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_1 = AsyncHook() + hook_1.before = AsyncMock(return_value=EvaluationContext("foo", {"key_1": "val_1"})) + hook_2 = AsyncHook() + hook_2.before = AsyncMock(return_value=EvaluationContext("bar", {"key_2": "val_2"})) + hook_3 = AsyncHook() + hook_3.before = AsyncMock(return_value=None) + # When + context = await before_hooks_async( + FlagType.BOOLEAN, hook_context, [hook_1, hook_2, hook_3] + ) + + # Then + assert context == EvaluationContext("bar", {"key_1": "val_1", "key_2": "val_2"}) + + def test_after_hooks_run_after_method(mock_hook): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -134,6 +191,30 @@ def test_after_hooks_run_after_method(mock_hook): ) +@pytest.mark.asyncio +async def test_after_hooks_run_after_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + flag_evaluation_details = FlagEvaluationDetails( + hook_context.flag_key, "val", "unknown" + ) + hook_hints = MappingProxyType({}) + # When + await after_hooks_async( + FlagType.BOOLEAN, + hook_context, + flag_evaluation_details, + [mock_hook_async], + hook_hints, + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.after.assert_called_once() + mock_hook_async.after.assert_called_with( + hook_context=hook_context, details=flag_evaluation_details, hints=hook_hints + ) + + def test_finally_after_hooks_run_finally_after_method(mock_hook): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -146,3 +227,20 @@ def test_finally_after_hooks_run_finally_after_method(mock_hook): mock_hook.finally_after.assert_called_with( hook_context=hook_context, hints=hook_hints ) + + +@pytest.mark.asyncio +async def test_finally_after_hooks_run_finally_after_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_hints = MappingProxyType({}) + # When + await after_all_hooks_async( + FlagType.BOOLEAN, hook_context, [mock_hook_async], hook_hints + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.finally_after.assert_called_once() + mock_hook_async.finally_after.assert_called_with( + hook_context=hook_context, hints=hook_hints + ) diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 70e312d3..97a11bc7 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -1,58 +1,79 @@ import time import uuid from concurrent.futures import ThreadPoolExecutor -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest -from openfeature.api import add_hooks, clear_hooks, get_client_async, set_provider +from openfeature.api import add_hooks, get_client_async, set_provider from openfeature.client import AsyncOpenFeatureClient from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, OpenFeatureError from openfeature.flag_evaluation import FlagResolutionDetails, Reason -from openfeature.provider import FeatureProvider -from openfeature.provider.in_memory_provider import InMemoryFlag, AsyncInMemoryProvider -from openfeature.provider.no_op_provider import NoOpProvider from openfeature.hook import AsyncHook - +from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider.in_memory_provider import AsyncInMemoryProvider, InMemoryFlag +from openfeature.provider.no_op_provider import NoOpProvider async_hook = MagicMock(spec=AsyncHook) async_hook.before = AsyncMock(return_value=None) async_hook.after = AsyncMock(return_value=None) - +@pytest.mark.parametrize( + "default, variants, get_method, expected_value", + ( + ("true", {"true": True, "false": False}, "get_boolean", True), + ("String", {"String": "Variant"}, "get_string", "Variant"), + ("Number", {"Number": 100}, "get_integer", 100), + ("Float", {"Float": 10.23}, "get_float", 10.23), + ( + "Object", + {"Object": {"some": "object"}}, + "get_object", + {"some": "object"}, + ), + ), +) @pytest.mark.asyncio -async def test_should_pass_flag_metadata_from_resolution_to_evaluation_details_async(): +async def test_flag_resolution_to_evaluation_details_async( + default, variants, get_method, expected_value, clear_hooks_fixture +): # Given - clear_hooks() add_hooks([async_hook]) provider = AsyncInMemoryProvider( { "Key": InMemoryFlag( - "true", - {"true": True, "false": False}, + default, + variants, flag_metadata={"foo": "bar"}, ) } ) set_provider(provider, "my-async-client") - client = AsyncOpenFeatureClient("my-async-client", None) - + client.add_hooks([async_hook]) # When - details = await client.get_boolean_details(flag_key="Key", default_value=False) + details = await getattr(client, f"{get_method}_details")( + flag_key="Key", default_value=None + ) + value = await getattr(client, f"{get_method}_value")( + flag_key="Key", default_value=None + ) # Then - clear_hooks() assert details is not None assert details.flag_metadata == {"foo": "bar"} + assert details.value == expected_value + assert details.value == value @pytest.mark.asyncio -async def test_should_return_client_metadata_with_domain_async(): +async def test_should_return_client_metadata_with_domain_async( + no_op_provider_client_async, +): # Given - client = AsyncOpenFeatureClient("my-async-client", None, NoOpProvider()) # When - metadata = client.get_metadata() + metadata = no_op_provider_client_async.get_metadata() # Then assert metadata is not None assert metadata.domain == "my-async-client" @@ -62,7 +83,6 @@ def test_add_remove_event_handler_async(): # Given provider = NoOpProvider() set_provider(provider) - spy = MagicMock() client = get_client_async() @@ -106,78 +126,70 @@ def emit_events_task(): f1.result() f2.result() -@pytest.mark.asyncio -async def test_evaluate_boolean_flag_details_async(): - # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_boolean_details.return_value = FlagResolutionDetails( - value=True, - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client_async() - # When - flag = await client.evaluate_flag_details( - flag_type=bool, flag_key="Key", default_value=True - ) - - # Then - assert flag is not None - assert flag.value == True +@pytest.mark.parametrize( + "provider_status, error_code", + ( + (ProviderStatus.NOT_READY, ErrorCode.PROVIDER_NOT_READY), + (ProviderStatus.FATAL, ErrorCode.PROVIDER_FATAL), + ), +) @pytest.mark.asyncio -async def test_evaluate_string_flag_details_async(): +async def test_should_shortcircuit_if_provider_is_not_ready( + no_op_provider_client_async, monkeypatch, provider_status, error_code +): # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_string_details.return_value = FlagResolutionDetails( - value="String", - reason=Reason.TARGETING_MATCH, + monkeypatch.setattr( + no_op_provider_client_async, + "get_provider_status", + lambda: provider_status, ) - set_provider(provider) - client = get_client_async() + spy_hook = MagicMock(spec=AsyncHook) + spy_hook.before = AsyncMock(return_value=None) + no_op_provider_client_async.add_hooks([spy_hook]) # When - flag = await client.evaluate_flag_details( - flag_type=str, flag_key="Key", default_value="String" + flag_details = await no_op_provider_client_async.get_boolean_details( + flag_key="Key", default_value=True ) - # Then - assert flag is not None - assert flag.value == "String" - + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == error_code + spy_hook.error.assert_called_once() + + +@pytest.mark.parametrize( + "expected_type, get_method, default_value", + ( + (bool, "get_boolean_details", True), + (str, "get_string_details", "default"), + (int, "get_integer_details", 100), + (float, "get_float_details", 10.23), + (dict, "get_object_details", {"some": "object"}), + ), +) @pytest.mark.asyncio -async def test_evaluate_integer_flag_details(): +async def test_handle_an_open_feature_exception_thrown_by_a_provider_async( + expected_type, + get_method, + default_value, + no_op_provider_client_async, +): # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_integer_details.return_value = FlagResolutionDetails( - value=100, - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client_async() - # When - flag = await client.evaluate_flag_details( - flag_type=int, flag_key="Key", default_value=100 + exception_hook = AsyncHook() + exception_hook.after = AsyncMock( + side_effect=OpenFeatureError(ErrorCode.GENERAL, "error_message") ) + no_op_provider_client_async.add_hooks([exception_hook]) - # Then - assert flag is not None - assert flag.value == 100 - -@pytest.mark.asyncio -async def test_evaluate_float_flag_details_async(): - # Given - provider = MagicMock(spec=FeatureProvider) - provider.resolve_float_details.return_value = FlagResolutionDetails( - value=10.23, - reason=Reason.TARGETING_MATCH, - ) - set_provider(provider) - client = get_client_async() # When - flag = await client.evaluate_flag_details( - flag_type=float, flag_key="Key", default_value=10.23 + flag_details = await getattr(no_op_provider_client_async, get_method)( + flag_key="Key", default_value=default_value ) - # Then - assert flag is not None - assert flag.value == 10.23 \ No newline at end of file + assert flag_details is not None + assert flag_details.value + assert isinstance(flag_details.value, expected_type) + assert flag_details.reason == Reason.ERROR + assert flag_details.error_message == "error_message"