diff --git a/openfeature/api.py b/openfeature/api.py
index c95d10ac..f0acc3c7 100644
--- a/openfeature/api.py
+++ b/openfeature/api.py
@@ -1,7 +1,7 @@
 import typing

 from openfeature import _event_support
-from openfeature.client import OpenFeatureClient
+from openfeature.client import AsyncOpenFeatureClient, OpenFeatureClient
 from openfeature.evaluation_context import EvaluationContext
 from openfeature.event import (
     EventHandler,
@@ -39,6 +39,12 @@ def get_client(
     return OpenFeatureClient(domain=domain, version=version)


+def get_client_async(
+    domain: typing.Optional[str] = None, version: typing.Optional[str] = None
+) -> AsyncOpenFeatureClient:
+    return AsyncOpenFeatureClient(domain=domain, version=version)
+
+
 def set_provider(
     provider: FeatureProvider, domain: typing.Optional[str] = None
 ) -> None:
diff --git a/openfeature/client.py b/openfeature/client.py
index 9e4518ec..4adf2974 100644
--- a/openfeature/client.py
+++ b/openfeature/client.py
@@ -23,9 +23,13 @@ from openfeature.hook import Hook, HookContext
 from openfeature.hook._hook_support import (
     after_all_hooks,
+    after_all_hooks_async,
     after_hooks,
+    after_hooks_async,
     before_hooks,
+    before_hooks_async,
     error_hooks,
+    error_hooks_async,
 )
 from openfeature.provider import FeatureProvider, ProviderStatus
 from openfeature.provider._registry import provider_registry
@@ -33,6 +37,7 @@
 __all__ = [
     "ClientMetadata",
     "OpenFeatureClient",
+    "AsyncOpenFeatureClient",
 ]

 logger = logging.getLogger("openfeature")
@@ -463,3 +468,359 @@ def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None:
         raise GeneralError(error_message="Unknown flag type")
     if not isinstance(value, _type):
         raise TypeMismatchError(f"Expected type {_type} but got {type(value)}")
+
+
+class AsyncOpenFeatureClient(OpenFeatureClient):
+    async def get_boolean_value(
+        self,
+        flag_key: str,
+        default_value: bool,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> bool:
+        details = await self.get_boolean_details(
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+        return details.value
+
+    async def get_boolean_details(
+        self,
+        flag_key: str,
+        default_value: bool,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> FlagEvaluationDetails[bool]:
+        return await self.evaluate_flag_details(
+            FlagType.BOOLEAN,
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+
+    async def get_string_value(
+        self,
+        flag_key: str,
+        default_value: str,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> str:
+        details = await self.get_string_details(
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+        return details.value
+
+    async def get_string_details(
+        self,
+        flag_key: str,
+        default_value: str,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> FlagEvaluationDetails[str]:
+        return await self.evaluate_flag_details(
+            FlagType.STRING,
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+
+    async def get_integer_value(
+        self,
+        flag_key: str,
+        default_value: int,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> int:
+        details = await self.get_integer_details(
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+        return details.value
+
+    async def get_integer_details(
+        self,
+        flag_key: str,
+        default_value: int,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> FlagEvaluationDetails[int]:
+        return await self.evaluate_flag_details(
+            FlagType.INTEGER,
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+
+    async def get_float_value(
+        self,
+        flag_key: str,
+        default_value: float,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> float:
+        details = await self.get_float_details(
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+        return details.value
+
+    async def get_float_details(
+        self,
+        flag_key: str,
+        default_value: float,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> FlagEvaluationDetails[float]:
+        return await self.evaluate_flag_details(
+            FlagType.FLOAT,
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+
+    async def get_object_value(
+        self,
+        flag_key: str,
+        default_value: typing.Union[dict, list],
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> typing.Union[dict, list]:
+        details = await self.get_object_details(
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+        return details.value
+
+    async def get_object_details(
+        self,
+        flag_key: str,
+        default_value: typing.Union[dict, list],
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> FlagEvaluationDetails[typing.Union[dict, list]]:
+        return await self.evaluate_flag_details(
+            FlagType.OBJECT,
+            flag_key,
+            default_value,
+            evaluation_context,
+            flag_evaluation_options,
+        )
+
+    async def evaluate_flag_details(  # noqa: PLR0915
+        self,
+        flag_type: FlagType,
+        flag_key: str,
+        default_value: typing.Any,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+        flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
+    ) -> FlagEvaluationDetails[typing.Any]:
+        """
+        Evaluate the flag requested by the user from the client's provider.
+
+        :param flag_type: the type of the flag being returned
+        :param flag_key: the string key of the selected flag
+        :param default_value: backup value returned if no result found by the provider
+        :param evaluation_context: Information for the purposes of flag evaluation
+        :param flag_evaluation_options: Additional flag evaluation information
+        :return: a FlagEvaluationDetails object with the fully evaluated flag from a
+        provider
+        """
+        if evaluation_context is None:
+            evaluation_context = EvaluationContext()
+
+        if flag_evaluation_options is None:
+            flag_evaluation_options = FlagEvaluationOptions()
+
+        provider = self.provider  # call this once to maintain a consistent reference
+        evaluation_hooks = flag_evaluation_options.hooks
+        hook_hints = flag_evaluation_options.hook_hints
+
+        hook_context = HookContext(
+            flag_key=flag_key,
+            flag_type=flag_type,
+            default_value=default_value,
+            evaluation_context=evaluation_context,
+            client_metadata=self.get_metadata(),
+            provider_metadata=provider.get_metadata(),
+        )
+        # Hooks need to be handled in different orders at different stages
+        # in the flag evaluation
+        # before: API, Client, Invocation, Provider
+        merged_hooks = (
+            api.get_hooks()
+            + self.hooks
+            + evaluation_hooks
+            + provider.get_provider_hooks()
+        )
+        # after, error, finally: Provider, Invocation, Client, API
+        reversed_merged_hooks = merged_hooks[:]
+        reversed_merged_hooks.reverse()
+
+        status = self.get_provider_status()
+        if status == ProviderStatus.NOT_READY:
+            await error_hooks_async(
+                flag_type,
+                hook_context,
+                ProviderNotReadyError(),
+                reversed_merged_hooks,
+                hook_hints,
+            )
+            return FlagEvaluationDetails(
+                flag_key=flag_key,
+                value=default_value,
+                reason=Reason.ERROR,
+                error_code=ErrorCode.PROVIDER_NOT_READY,
+            )
+        if status == ProviderStatus.FATAL:
+            await error_hooks_async(
+                flag_type,
+                hook_context,
+                ProviderFatalError(),
+                reversed_merged_hooks,
+                hook_hints,
+            )
+            return FlagEvaluationDetails(
+                flag_key=flag_key,
+                value=default_value,
+                reason=Reason.ERROR,
+                error_code=ErrorCode.PROVIDER_FATAL,
+            )
+
+        try:
+            # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
+            # Any resulting evaluation context from a before hook will overwrite
+            # duplicate fields defined globally, on the client, or in the invocation.
+            # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context
+            invocation_context = await before_hooks_async(
+                flag_type, hook_context, merged_hooks, hook_hints
+            )
+
+            invocation_context = invocation_context.merge(ctx2=evaluation_context)
+            # Requirement 3.2.2 merge: API.context->client.context->invocation.context
+            merged_context = (
+                api.get_evaluation_context()
+                .merge(self.context)
+                .merge(invocation_context)
+            )
+
+            flag_evaluation = await self._create_provider_evaluation(
+                provider,
+                flag_type,
+                flag_key,
+                default_value,
+                merged_context,
+            )
+
+            await after_hooks_async(
+                flag_type,
+                hook_context,
+                flag_evaluation,
+                reversed_merged_hooks,
+                hook_hints,
+            )
+
+            return flag_evaluation
+
+        except OpenFeatureError as err:
+            await error_hooks_async(
+                flag_type, hook_context, err, reversed_merged_hooks, hook_hints
+            )
+
+            return FlagEvaluationDetails(
+                flag_key=flag_key,
+                value=default_value,
+                reason=Reason.ERROR,
+                error_code=err.error_code,
+                error_message=err.error_message,
+            )
+        # Catch any type of exception here since the user can provide any exception
+        # in the error hooks
+        except Exception as err:  # pragma: no cover
+            logger.exception(
+                "Unable to correctly evaluate flag with key: '%s'", flag_key
+            )
+
+            await error_hooks_async(
+                flag_type, hook_context, err, reversed_merged_hooks, hook_hints
+            )
+
+            error_message = getattr(err, "error_message", str(err))
+            return FlagEvaluationDetails(
+                flag_key=flag_key,
+                value=default_value,
+                reason=Reason.ERROR,
+                error_code=ErrorCode.GENERAL,
+                error_message=error_message,
+            )
+
+        finally:
+            await after_all_hooks_async(
+                flag_type, hook_context, reversed_merged_hooks, hook_hints
+            )
+
+    async def _create_provider_evaluation(
+        self,
+        provider: FeatureProvider,
+        flag_type: FlagType,
+        flag_key: str,
+        default_value: typing.Any,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagEvaluationDetails[typing.Any]:
+        """
+        Asynchronous encapsulated method to create a FlagEvaluationDetail from a specific provider.
+
+        :param flag_type: the type of the flag being returned
+        :param flag_key: the string key of the selected flag
+        :param default_value: backup value returned if no result found by the provider
+        :param evaluation_context: Information for the purposes of flag evaluation
+        :return: a FlagEvaluationDetails object with the fully evaluated flag from a
+        provider
+        """
+        args = (
+            flag_key,
+            default_value,
+            evaluation_context,
+        )
+
+        get_details_callables: typing.Mapping[FlagType, GetDetailCallable] = {
+            FlagType.BOOLEAN: provider.resolve_boolean_details,
+            FlagType.INTEGER: provider.resolve_integer_details,
+            FlagType.FLOAT: provider.resolve_float_details,
+            FlagType.OBJECT: provider.resolve_object_details,
+            FlagType.STRING: provider.resolve_string_details,
+        }
+
+        get_details_callable = get_details_callables.get(flag_type)
+        if not get_details_callable:
+            raise GeneralError(error_message="Unknown flag type")
+
+        resolution = await get_details_callable(*args)
+        resolution.raise_for_error()
+
+        # we need to check the get_args to be compatible with union types.
+        _typecheck_flag_value(resolution.value, flag_type)
+
+        return FlagEvaluationDetails(
+            flag_key=flag_key,
+            value=resolution.value,
+            variant=resolution.variant,
+            flag_metadata=resolution.flag_metadata or {},
+            reason=resolution.reason,
+            error_code=resolution.error_code,
+            error_message=resolution.error_message,
+        )
diff --git a/openfeature/hook/__init__.py b/openfeature/hook/__init__.py
index e98301fa..546d8cab 100644
--- a/openfeature/hook/__init__.py
+++ b/openfeature/hook/__init__.py
@@ -128,3 +128,66 @@ def supports_flag_value_type(self, flag_type: FlagType) -> bool:
         or not (False)
         """
         return True
+
+
+class AsyncHook:
+    async def before(
+        self, hook_context: HookContext, hints: HookHints
+    ) -> typing.Optional[EvaluationContext]:
+        """
+        Runs before flag is resolved.
+
+        :param hook_context: Information about the particular flag evaluation
+        :param hints: An immutable mapping of data for users to
+        communicate to the hooks.
+        :return: An EvaluationContext. It will be merged with the
+        EvaluationContext instances from other hooks, the client and API.
+        """
+        return None
+
+    async def after(
+        self,
+        hook_context: HookContext,
+        details: FlagEvaluationDetails[typing.Any],
+        hints: HookHints,
+    ) -> None:
+        """
+        Runs after a flag is resolved.
+
+        :param hook_context: Information about the particular flag evaluation
+        :param details: Information about how the flag was resolved,
+        including any resolved values.
+        :param hints: A mapping of data for users to communicate to the hooks.
+        """
+        pass
+
+    async def error(
+        self, hook_context: HookContext, exception: Exception, hints: HookHints
+    ) -> None:
+        """
+        Run when evaluation encounters an error. Errors thrown will be swallowed.
+
+        :param hook_context: Information about the particular flag evaluation
+        :param exception: The exception that was thrown
+        :param hints: A mapping of data for users to communicate to the hooks.
+        """
+        pass
+
+    async def finally_after(self, hook_context: HookContext, hints: HookHints) -> None:
+        """
+        Run after flag evaluation, including any error processing.
+        This will always run. Errors will be swallowed.
+
+        :param hook_context: Information about the particular flag evaluation
+        :param hints: A mapping of data for users to communicate to the hooks.
+        """
+        pass
+
+    def supports_flag_value_type(self, flag_type: FlagType) -> bool:
+        """
+        Check to see if the hook supports the particular flag type.
+
+        :param flag_type: particular type of the flag
+        :return: a boolean containing whether the flag type is supported (True)
+        or not (False)
+        """
+        return True
diff --git a/openfeature/hook/_hook_support.py b/openfeature/hook/_hook_support.py
index 349b25f3..ce817b36 100644
--- a/openfeature/hook/_hook_support.py
+++ b/openfeature/hook/_hook_support.py
@@ -22,6 +22,19 @@ def error_hooks(
     )


+async def error_hooks_async(
+    flag_type: FlagType,
+    hook_context: HookContext,
+    exception: Exception,
+    hooks: typing.List[Hook],
+    hints: typing.Optional[HookHints] = None,
+) -> None:
+    kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints}
+    await _execute_hooks_async(
+        flag_type=flag_type, hooks=hooks, hook_method=HookType.ERROR, **kwargs
+    )
+
+
 def after_all_hooks(
     flag_type: FlagType,
     hook_context: HookContext,
@@ -34,6 +47,18 @@ def after_all_hooks(
     )


+async def after_all_hooks_async(
+    flag_type: FlagType,
+    hook_context: HookContext,
+    hooks: typing.List[Hook],
+    hints: typing.Optional[HookHints] = None,
+) -> None:
+    kwargs = {"hook_context": hook_context, "hints": hints}
+    await _execute_hooks_async(
+        flag_type=flag_type, hooks=hooks, hook_method=HookType.FINALLY_AFTER, **kwargs
+    )
+
+
 def after_hooks(
     flag_type: FlagType,
     hook_context: HookContext,
@@ -47,6 +72,19 @@ def after_hooks(
     )


+async def after_hooks_async(
+    flag_type: FlagType,
+    hook_context: HookContext,
+    details: FlagEvaluationDetails[typing.Any],
+    hooks: typing.List[Hook],
+    hints: typing.Optional[HookHints] = None,
+) -> None:
+    kwargs = {"hook_context": hook_context, "details": details, "hints": hints}
+    await _execute_hooks_async_unchecked(
+        flag_type=flag_type, hooks=hooks, hook_method=HookType.AFTER, **kwargs
+    )
+
+
 def before_hooks(
     flag_type: FlagType,
     hook_context: HookContext,
@@ -65,6 +103,23 @@ def before_hooks(
     return EvaluationContext()


+async def before_hooks_async(
+    flag_type: FlagType,
+    hook_context: HookContext,
+    hooks: typing.List[Hook],
+    hints: typing.Optional[HookHints] = None,
+) -> EvaluationContext:
+    kwargs = {"hook_context": hook_context, "hints": hints}
+    executed_hooks = await _execute_hooks_async(
+        flag_type=flag_type, hooks=hooks, hook_method=HookType.BEFORE, **kwargs
+    )
+    filtered_hooks = [result for result in executed_hooks if result is not None]
+    if filtered_hooks:
+        return reduce(lambda a, b: a.merge(b), filtered_hooks)
+
+    return EvaluationContext()
+
+
 def _execute_hooks(
     flag_type: FlagType,
     hooks: typing.List[Hook],
@@ -88,6 +143,29 @@ def _execute_hooks(
     ]


+async def _execute_hooks_async(
+    flag_type: FlagType,
+    hooks: typing.List[Hook],
+    hook_method: HookType,
+    **kwargs: typing.Any,
+) -> list:
+    """
+    Run multiple hooks of any hook type. All of these hooks will be run through an
+    exception check.
+
+    :param flag_type: particular type of flag
+    :param hooks: a list of hooks
+    :param hook_method: the type of hook that is being run
+    :param kwargs: arguments that need to be provided to the hook method
+    :return: a list of results from the applied hook methods
+    """
+    return [
+        await _execute_hook_checked_async(hook, hook_method, **kwargs)
+        for hook in hooks
+        if hook.supports_flag_value_type(flag_type)
+    ]
+
+
 def _execute_hooks_unchecked(
     flag_type: FlagType,
     hooks: typing.List[Hook],
@@ -112,6 +190,30 @@ def _execute_hooks_unchecked(
     ]


+async def _execute_hooks_async_unchecked(
+    flag_type: FlagType,
+    hooks: typing.List[Hook],
+    hook_method: HookType,
+    **kwargs: typing.Any,
+) -> typing.List[typing.Optional[EvaluationContext]]:
+    """
+    Execute multiple hooks without checking whether an exception is thrown. This is
+    used in the before and after hooks since any exception will be caught in the
+    client.
+
+    :param flag_type: particular type of flag
+    :param hooks: a list of hooks
+    :param hook_method: the type of hook that is being run
+    :param kwargs: arguments that need to be provided to the hook method
+    :return: a list of results from the applied hook methods
+    """
+    return [
+        await getattr(hook, hook_method.value)(**kwargs)
+        for hook in hooks
+        if hook.supports_flag_value_type(flag_type)
+    ]
+
+
 def _execute_hook_checked(
     hook: Hook, hook_method: HookType, **kwargs: typing.Any
 ) -> typing.Optional[EvaluationContext]:
@@ -132,3 +234,25 @@ def _execute_hook_checked(
     except Exception:  # pragma: no cover
         logger.exception(f"Exception when running {hook_method.value} hooks")
         return None
+
+
+async def _execute_hook_checked_async(
+    hook: Hook, hook_method: HookType, **kwargs: typing.Any
+) -> typing.Optional[EvaluationContext]:
+    """
+    Try and run a single hook and catch any exception thrown. This is used in the
+    after all and error hooks since any error thrown at this point needs to be caught.
+
+    :param hook: the hook being executed
+    :param hook_method: the type of hook that is being run
+    :param kwargs: arguments that need to be provided to the hook method
+    :return: the result of the hook method
+    """
+    try:
+        return typing.cast(
+            "typing.Optional[EvaluationContext]",
+            await getattr(hook, hook_method.value)(**kwargs),
+        )
+    except Exception:  # pragma: no cover
+        logger.exception(f"Exception when running {hook_method.value} hooks")
+        return None
diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py
index 8927551e..f22a45ce 100644
--- a/openfeature/provider/__init__.py
+++ b/openfeature/provider/__init__.py
@@ -164,3 +164,50 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None:
     def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None:
         if hasattr(self, "_on_emit"):
             self._on_emit(self, event, details)
+
+
+class AsyncAbstractProvider(AbstractProvider):
+    @abstractmethod
+    async def resolve_boolean_details(
+        self,
+        flag_key: str,
+        default_value: bool,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[bool]:
+        pass
+
+    @abstractmethod
+    async def resolve_string_details(
+        self,
+        flag_key: str,
+        default_value: str,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[str]:
+        pass
+
+    @abstractmethod
+    async def resolve_integer_details(
+        self,
+        flag_key: str,
+        default_value: int,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[int]:
+        pass
+
+    @abstractmethod
+    async def resolve_float_details(
+        self,
+        flag_key: str,
+        default_value: float,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[float]:
+        pass
+
+    @abstractmethod
+    async def resolve_object_details(
+        self,
+        flag_key: str,
+        default_value: typing.Union[dict, list],
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[typing.Union[dict, list]]:
+        pass
diff --git a/openfeature/provider/in_memory_provider.py b/openfeature/provider/in_memory_provider.py
index 322f4ed6..7debd0d3 100644
--- a/openfeature/provider/in_memory_provider.py
+++ b/openfeature/provider/in_memory_provider.py
@@ -117,3 +117,56 @@ def _resolve(
         if flag is None:
             raise FlagNotFoundError(f"Flag '{flag_key}' not found")
         return flag.resolve(evaluation_context)
+
+
+class AsyncInMemoryProvider(InMemoryProvider):
+    _flags: FlagStorage
+
+    def __init__(self, flags: FlagStorage) -> None:
+        self._flags = flags.copy()
+
+    def get_metadata(self) -> Metadata:
+        return InMemoryMetadata()
+
+    def get_provider_hooks(self) -> typing.List[Hook]:
+        return []
+
+    async def resolve_boolean_details(
+        self,
+        flag_key: str,
+        default_value: bool,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[bool]:
+        return self._resolve(flag_key, evaluation_context)
+
+    async def resolve_string_details(
+        self,
+        flag_key: str,
+        default_value: str,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[str]:
+        return self._resolve(flag_key, evaluation_context)
+
+    async def resolve_integer_details(
+        self,
+        flag_key: str,
+        default_value: int,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[int]:
+        return self._resolve(flag_key, evaluation_context)
+
+    async def resolve_float_details(
+        self,
+        flag_key: str,
+        default_value: float,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[float]:
+        return self._resolve(flag_key, evaluation_context)
+
+    async def resolve_object_details(
+        self,
+        flag_key: str,
+        default_value: typing.Union[dict, list],
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[typing.Union[dict, list]]:
+        return self._resolve(flag_key, evaluation_context)
\ No newline at end of file
diff --git a/openfeature/provider/no_op_provider.py b/openfeature/provider/no_op_provider.py
index 070945c9..cba1d4fb 100644
--- a/openfeature/provider/no_op_provider.py
+++ b/openfeature/provider/no_op_provider.py
@@ -75,3 +75,65 @@ def resolve_object_details(
             reason=Reason.DEFAULT,
             variant=PASSED_IN_DEFAULT,
         )
+
+
+class AsyncNoOpProvider(NoOpProvider):
+    async def resolve_boolean_details(
+        self,
+        flag_key: str,
+        default_value: bool,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[bool]:
+        return FlagResolutionDetails(
+            value=default_value,
+            reason=Reason.DEFAULT,
+            variant=PASSED_IN_DEFAULT,
+        )
+
+    async def resolve_string_details(
+        self,
+        flag_key: str,
+        default_value: str,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[str]:
+        return FlagResolutionDetails(
+            value=default_value,
+            reason=Reason.DEFAULT,
+            variant=PASSED_IN_DEFAULT,
+        )
+
+    async def resolve_integer_details(
+        self,
+        flag_key: str,
+        default_value: int,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[int]:
+        return FlagResolutionDetails(
+            value=default_value,
+            reason=Reason.DEFAULT,
+            variant=PASSED_IN_DEFAULT,
+        )
+
+    async def resolve_float_details(
+        self,
+        flag_key: str,
+        default_value: float,
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[float]:
+        return FlagResolutionDetails(
+            value=default_value,
+            reason=Reason.DEFAULT,
+            variant=PASSED_IN_DEFAULT,
+        )
+
+    async def resolve_object_details(
+        self,
+        flag_key: str,
+        default_value: typing.Union[dict, list],
+        evaluation_context: typing.Optional[EvaluationContext] = None,
+    ) -> FlagResolutionDetails[typing.Union[dict, list]]:
+        return FlagResolutionDetails(
+            value=default_value,
+            reason=Reason.DEFAULT,
+            variant=PASSED_IN_DEFAULT,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index 40887d9d..f17eaba6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
   "behave",
   "coverage[toml]>=6.5",
   "pytest",
+  "pytest-asyncio",
 ]

 [tool.hatch.envs.default.scripts]
diff --git a/tests/conftest.py b/tests/conftest.py
index 1f0a7982..b3dcc16a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,7 @@
 import pytest

 from openfeature import api
-from openfeature.provider.no_op_provider import NoOpProvider
+from openfeature.provider.no_op_provider import AsyncNoOpProvider, NoOpProvider


 @pytest.fixture(autouse=True)
@@ -13,7 +13,22 @@ def clear_providers():
     api.clear_providers()


+@pytest.fixture(autouse=True)
+def clear_hooks_fixture():
+    """
+    For tests that use add_hooks(), we need to clear the hooks to avoid issues
+    in other tests.
+ """ + api.clear_hooks() + + @pytest.fixture() def no_op_provider_client(): api.set_provider(NoOpProvider()) return api.get_client() + + +@pytest.fixture() +def no_op_provider_client_async(): + api.set_provider(AsyncNoOpProvider()) + return api.get_client_async("my-async-client") diff --git a/tests/hook/conftest.py b/tests/hook/conftest.py index 5a3d8092..bf4af84e 100644 --- a/tests/hook/conftest.py +++ b/tests/hook/conftest.py @@ -3,6 +3,7 @@ import pytest from openfeature.evaluation_context import EvaluationContext +from openfeature.hook import AsyncHook @pytest.fixture() @@ -14,3 +15,14 @@ def mock_hook(): mock_hook.error.return_value = None mock_hook.finally_after.return_value = None return mock_hook + + +@pytest.fixture() +def mock_hook_async(): + mock_hook = AsyncHook() + mock_hook.supports_flag_value_type = mock.MagicMock(return_value=True) + mock_hook.before = mock.AsyncMock(return_value=None) + mock_hook.after = mock.AsyncMock(return_value=None) + mock_hook.error = mock.AsyncMock(return_value=None) + mock_hook.finally_after = mock.AsyncMock(return_value=None) + return mock_hook diff --git a/tests/hook/test_hook_support.py b/tests/hook/test_hook_support.py index 64bb8f6f..49997107 100644 --- a/tests/hook/test_hook_support.py +++ b/tests/hook/test_hook_support.py @@ -1,16 +1,20 @@ -from unittest.mock import ANY, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock import pytest from openfeature.client import ClientMetadata from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagEvaluationDetails, FlagType -from openfeature.hook import Hook, HookContext +from openfeature.hook import AsyncHook, Hook, HookContext from openfeature.hook._hook_support import ( after_all_hooks, + after_all_hooks_async, after_hooks, + after_hooks_async, before_hooks, + before_hooks_async, error_hooks, + error_hooks_async, ) from openfeature.immutable_dict.mapping_proxy_type import MappingProxyType from openfeature.provider.metadata import Metadata @@ -86,6 +90,23 @@ def test_error_hooks_run_error_method(mock_hook): ) +@pytest.mark.asyncio +async def test_error_hooks_run_error_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_hints = MappingProxyType({}) + # When + await error_hooks_async( + FlagType.BOOLEAN, hook_context, Exception, [mock_hook_async], hook_hints + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.error.assert_called_once() + mock_hook_async.error.assert_called_with( + hook_context=hook_context, exception=ANY, hints=hook_hints + ) + + def test_before_hooks_run_before_method(mock_hook): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -98,6 +119,23 @@ def test_before_hooks_run_before_method(mock_hook): mock_hook.before.assert_called_with(hook_context=hook_context, hints=hook_hints) +@pytest.mark.asyncio +async def test_before_hooks_run_before_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_hints = MappingProxyType({}) + # When + await before_hooks_async( + FlagType.BOOLEAN, hook_context, [mock_hook_async], hook_hints + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.before.assert_called_once() + mock_hook_async.before.assert_called_with( + hook_context=hook_context, hints=hook_hints + ) + + def test_before_hooks_merges_evaluation_contexts(): # Given hook_context = 
HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -115,6 +153,25 @@ def test_before_hooks_merges_evaluation_contexts(): assert context == EvaluationContext("bar", {"key_1": "val_1", "key_2": "val_2"}) +@pytest.mark.asyncio +async def test_before_hooks_async_merges_evaluation_contexts(): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_1 = AsyncHook() + hook_1.before = AsyncMock(return_value=EvaluationContext("foo", {"key_1": "val_1"})) + hook_2 = AsyncHook() + hook_2.before = AsyncMock(return_value=EvaluationContext("bar", {"key_2": "val_2"})) + hook_3 = AsyncHook() + hook_3.before = AsyncMock(return_value=None) + # When + context = await before_hooks_async( + FlagType.BOOLEAN, hook_context, [hook_1, hook_2, hook_3] + ) + + # Then + assert context == EvaluationContext("bar", {"key_1": "val_1", "key_2": "val_2"}) + + def test_after_hooks_run_after_method(mock_hook): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -134,6 +191,30 @@ def test_after_hooks_run_after_method(mock_hook): ) +@pytest.mark.asyncio +async def test_after_hooks_run_after_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + flag_evaluation_details = FlagEvaluationDetails( + hook_context.flag_key, "val", "unknown" + ) + hook_hints = MappingProxyType({}) + # When + await after_hooks_async( + FlagType.BOOLEAN, + hook_context, + flag_evaluation_details, + [mock_hook_async], + hook_hints, + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.after.assert_called_once() + mock_hook_async.after.assert_called_with( + hook_context=hook_context, details=flag_evaluation_details, hints=hook_hints + ) + + def test_finally_after_hooks_run_finally_after_method(mock_hook): # Given hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") @@ -146,3 +227,20 @@ def test_finally_after_hooks_run_finally_after_method(mock_hook): mock_hook.finally_after.assert_called_with( hook_context=hook_context, hints=hook_hints ) + + +@pytest.mark.asyncio +async def test_finally_after_hooks_run_finally_after_method_async(mock_hook_async): + # Given + hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, "") + hook_hints = MappingProxyType({}) + # When + await after_all_hooks_async( + FlagType.BOOLEAN, hook_context, [mock_hook_async], hook_hints + ) + # Then + mock_hook_async.supports_flag_value_type.assert_called_once() + mock_hook_async.finally_after.assert_called_once() + mock_hook_async.finally_after.assert_called_with( + hook_context=hook_context, hints=hook_hints + ) diff --git a/tests/test_async_client.py b/tests/test_async_client.py new file mode 100644 index 00000000..97a11bc7 --- /dev/null +++ b/tests/test_async_client.py @@ -0,0 +1,195 @@ +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from openfeature.api import add_hooks, get_client_async, set_provider +from openfeature.client import AsyncOpenFeatureClient +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, OpenFeatureError +from openfeature.flag_evaluation import FlagResolutionDetails, Reason +from openfeature.hook import AsyncHook +from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider.in_memory_provider import AsyncInMemoryProvider, InMemoryFlag +from openfeature.provider.no_op_provider import NoOpProvider + 
+async_hook = MagicMock(spec=AsyncHook)
+async_hook.before = AsyncMock(return_value=None)
+async_hook.after = AsyncMock(return_value=None)
+
+
+@pytest.mark.parametrize(
+    "default, variants, get_method, expected_value",
+    (
+        ("true", {"true": True, "false": False}, "get_boolean", True),
+        ("String", {"String": "Variant"}, "get_string", "Variant"),
+        ("Number", {"Number": 100}, "get_integer", 100),
+        ("Float", {"Float": 10.23}, "get_float", 10.23),
+        (
+            "Object",
+            {"Object": {"some": "object"}},
+            "get_object",
+            {"some": "object"},
+        ),
+    ),
+)
+@pytest.mark.asyncio
+async def test_flag_resolution_to_evaluation_details_async(
+    default, variants, get_method, expected_value, clear_hooks_fixture
+):
+    # Given
+    add_hooks([async_hook])
+    provider = AsyncInMemoryProvider(
+        {
+            "Key": InMemoryFlag(
+                default,
+                variants,
+                flag_metadata={"foo": "bar"},
+            )
+        }
+    )
+    set_provider(provider, "my-async-client")
+    client = AsyncOpenFeatureClient("my-async-client", None)
+    client.add_hooks([async_hook])
+    # When
+    details = await getattr(client, f"{get_method}_details")(
+        flag_key="Key", default_value=None
+    )
+    value = await getattr(client, f"{get_method}_value")(
+        flag_key="Key", default_value=None
+    )
+    # Then
+    assert details is not None
+    assert details.flag_metadata == {"foo": "bar"}
+    assert details.value == expected_value
+    assert details.value == value
+
+
+@pytest.mark.asyncio
+async def test_should_return_client_metadata_with_domain_async(
+    no_op_provider_client_async,
+):
+    # Given
+    # When
+    metadata = no_op_provider_client_async.get_metadata()
+    # Then
+    assert metadata is not None
+    assert metadata.domain == "my-async-client"
+
+
+def test_add_remove_event_handler_async():
+    # Given
+    provider = NoOpProvider()
+    set_provider(provider)
+    spy = MagicMock()
+
+    client = get_client_async()
+    client.add_handler(
+        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
+    )
+    client.remove_handler(
+        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
+    )
+
+    provider_details = ProviderEventDetails(message="message")
+
+    # When
+    provider.emit_provider_configuration_changed(provider_details)
+
+    # Then
+    spy.provider_configuration_changed.assert_not_called()
+
+
+def test_client_handlers_thread_safety_async():
+    provider = NoOpProvider()
+    set_provider(provider)
+
+    def add_handlers_task():
+        def handler(*args, **kwargs):
+            time.sleep(0.005)
+
+        for _ in range(10):
+            time.sleep(0.01)
+            client = get_client_async(str(uuid.uuid4()))
+            client.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler)
+
+    def emit_events_task():
+        for _ in range(10):
+            time.sleep(0.01)
+            provider.emit_provider_configuration_changed(ProviderEventDetails())
+
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        f1 = executor.submit(add_handlers_task)
+        f2 = executor.submit(emit_events_task)
+        f1.result()
+        f2.result()
+
+
+@pytest.mark.parametrize(
+    "provider_status, error_code",
+    (
+        (ProviderStatus.NOT_READY, ErrorCode.PROVIDER_NOT_READY),
+        (ProviderStatus.FATAL, ErrorCode.PROVIDER_FATAL),
+    ),
+)
+@pytest.mark.asyncio
+async def test_should_shortcircuit_if_provider_is_not_ready(
+    no_op_provider_client_async, monkeypatch, provider_status, error_code
+):
+    # Given
+    monkeypatch.setattr(
+        no_op_provider_client_async,
+        "get_provider_status",
+        lambda: provider_status,
+    )
+    spy_hook = MagicMock(spec=AsyncHook)
+    spy_hook.before = AsyncMock(return_value=None)
+    no_op_provider_client_async.add_hooks([spy_hook])
+    # When
+    flag_details = await no_op_provider_client_async.get_boolean_details(
+        flag_key="Key", default_value=True
+    )
+    # Then
+    assert flag_details is not None
+    assert flag_details.value
+    assert flag_details.reason == Reason.ERROR
+    assert flag_details.error_code == error_code
+    spy_hook.error.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    "expected_type, get_method, default_value",
+    (
+        (bool, "get_boolean_details", True),
+        (str, "get_string_details", "default"),
+        (int, "get_integer_details", 100),
+        (float, "get_float_details", 10.23),
+        (dict, "get_object_details", {"some": "object"}),
+    ),
+)
+@pytest.mark.asyncio
+async def test_handle_an_open_feature_exception_thrown_by_a_provider_async(
+    expected_type,
+    get_method,
+    default_value,
+    no_op_provider_client_async,
+):
+    # Given
+    exception_hook = AsyncHook()
+    exception_hook.after = AsyncMock(
+        side_effect=OpenFeatureError(ErrorCode.GENERAL, "error_message")
+    )
+    no_op_provider_client_async.add_hooks([exception_hook])
+
+    # When
+    flag_details = await getattr(no_op_provider_client_async, get_method)(
+        flag_key="Key", default_value=default_value
+    )
+    # Then
+    assert flag_details is not None
+    assert flag_details.value
+    assert isinstance(flag_details.value, expected_type)
+    assert flag_details.reason == Reason.ERROR
+    assert flag_details.error_message == "error_message"
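A minimal usage sketch of the async surface introduced by this patch (not part of the diff itself): it only relies on names added above (set_provider, get_client_async, AsyncInMemoryProvider, InMemoryFlag); the "docs-example" domain and the "new-welcome-banner" flag key are illustrative placeholders.

# Illustrative usage of the async API added by this patch; not part of the diff.
# The domain "docs-example" and the flag key "new-welcome-banner" are placeholders.
import asyncio

from openfeature import api
from openfeature.provider.in_memory_provider import AsyncInMemoryProvider, InMemoryFlag


async def main() -> None:
    # Register an async-capable provider under a domain, then fetch the async client.
    api.set_provider(
        AsyncInMemoryProvider(
            {"new-welcome-banner": InMemoryFlag("on", {"on": True, "off": False})}
        ),
        "docs-example",
    )
    client = api.get_client_async("docs-example")

    # Resolution methods on AsyncOpenFeatureClient are coroutines, so they are awaited.
    enabled = await client.get_boolean_value("new-welcome-banner", False)
    details = await client.get_boolean_details("new-welcome-banner", False)
    print(enabled, details.reason)


asyncio.run(main())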