diff --git a/README.md b/README.md index 5d4cd896..17e38eea 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,15 @@ api.set_provider(NoOpProvider()) open_feature_client = api.get_client() ``` +`set_provider()` is non-blocking: it registers the provider immediately and runs initialization in a background thread. +Flag evaluations during the initialization window return the default value with a `PROVIDER_NOT_READY` error code. +Use `set_provider_and_wait()` if you need to ensure the provider is ready before proceeding: + +```python +# blocks until the provider is initialized (or raises on failure) +api.set_provider_and_wait(NoOpProvider()) +``` + In some situations, it may be beneficial to register multiple providers in the same application. This is possible using [domains](#domains), which is covered in more detail below. diff --git a/openfeature/api.py b/openfeature/api.py index 817104ab..4585e50e 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -33,6 +33,7 @@ "remove_handler", "set_evaluation_context", "set_provider", + "set_provider_and_wait", "set_transaction_context", "set_transaction_context_propagator", "shutdown", @@ -52,6 +53,13 @@ def set_provider(provider: FeatureProvider, domain: str | None = None) -> None: provider_registry.set_provider(domain, provider) +def set_provider_and_wait(provider: FeatureProvider, domain: str | None = None) -> None: + if domain is None: + provider_registry.set_default_provider(provider, wait_for_init=True) + else: + provider_registry.set_provider(domain, provider, wait_for_init=True) + + def clear_providers() -> None: provider_registry.clear_providers() _event_support.clear() diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index bf8fa9a8..34366a99 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -1,3 +1,5 @@ +import threading + from openfeature._event_support import run_handlers_for_provider from openfeature.evaluation_context import EvaluationContext, get_evaluation_context from openfeature.event import ( @@ -21,7 +23,9 @@ def __init__(self) -> None: self._default_provider: ProviderStatus.READY, } - def set_provider(self, domain: str, provider: FeatureProvider) -> None: + def set_provider( + self, domain: str, provider: FeatureProvider, wait_for_init: bool = False + ) -> None: if provider is None: raise GeneralError(error_message="No provider") if domain is None: @@ -36,7 +40,7 @@ def set_provider(self, domain: str, provider: FeatureProvider) -> None: ): self._shutdown_provider(old_provider) if provider != self._default_provider and provider not in providers.values(): - self._initialize_provider(provider) + self._initialize_provider(provider, wait_for_init=wait_for_init) providers[domain] = provider def get_provider(self, domain: str | None) -> FeatureProvider: @@ -44,7 +48,9 @@ def get_provider(self, domain: str | None) -> FeatureProvider: return self._default_provider return self._providers.get(domain, self._default_provider) - def set_default_provider(self, provider: FeatureProvider) -> None: + def set_default_provider( + self, provider: FeatureProvider, wait_for_init: bool = False + ) -> None: if provider is None: raise GeneralError(error_message="No provider") if ( @@ -55,7 +61,7 @@ def set_default_provider(self, provider: FeatureProvider) -> None: self._default_provider = provider if self._default_provider not in self._providers.values(): - self._initialize_provider(provider) + self._initialize_provider(provider, wait_for_init=wait_for_init) def get_default_provider(self) -> FeatureProvider: return self._default_provider @@ -75,8 +81,24 @@ def shutdown(self) -> None: def _get_evaluation_context(self) -> EvaluationContext: return get_evaluation_context() - def _initialize_provider(self, provider: FeatureProvider) -> None: + def _initialize_provider( + self, provider: FeatureProvider, wait_for_init: bool = False + ) -> None: provider.attach(self.dispatch_event) + if wait_for_init: + self._run_initialize(provider, raise_on_error=True) + else: + thread = threading.Thread( + target=self._run_initialize, + args=(provider,), + kwargs={"raise_on_error": False}, + daemon=True, + ) + thread.start() + + def _run_initialize( + self, provider: FeatureProvider, raise_on_error: bool = False + ) -> None: try: if hasattr(provider, "initialize"): provider.initialize(self._get_evaluation_context()) @@ -97,6 +119,8 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: error_code=error_code, ), ) + if raise_on_error: + raise def _shutdown_provider(self, provider: FeatureProvider) -> None: try: diff --git a/tests/conftest.py b/tests/conftest.py index 1f0a7982..495634c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,5 +15,5 @@ def clear_providers(): @pytest.fixture() def no_op_provider_client(): - api.set_provider(NoOpProvider()) + api.set_provider_and_wait(NoOpProvider()) return api.get_client() diff --git a/tests/features/steps/metadata_steps.py b/tests/features/steps/metadata_steps.py index 0154a9f0..bed87d17 100644 --- a/tests/features/steps/metadata_steps.py +++ b/tests/features/steps/metadata_steps.py @@ -1,13 +1,13 @@ from behave import given, then -from openfeature.api import get_client, set_provider +from openfeature.api import get_client, set_provider_and_wait from openfeature.provider.in_memory_provider import InMemoryProvider from tests.features.data import IN_MEMORY_FLAGS @given("a stable provider") def step_impl_stable_provider(context): - set_provider(InMemoryProvider(IN_MEMORY_FLAGS)) + set_provider_and_wait(InMemoryProvider(IN_MEMORY_FLAGS)) context.client = get_client() diff --git a/tests/features/steps/steps.py b/tests/features/steps/steps.py index 5d9d38fd..9b699331 100644 --- a/tests/features/steps/steps.py +++ b/tests/features/steps/steps.py @@ -4,7 +4,7 @@ from behave import given, then, when -from openfeature.api import get_client, set_provider +from openfeature.api import get_client, set_provider_and_wait from openfeature.client import OpenFeatureClient from openfeature.evaluation_context import EvaluationContext from openfeature.exception import ErrorCode @@ -28,13 +28,13 @@ def step_impl_resolved_should_be(context, flag_type, key, expected_reason): @given("a provider is registered with cache disabled") def step_impl_provider_without_cache(context): - set_provider(InMemoryProvider(IN_MEMORY_FLAGS)) + set_provider_and_wait(InMemoryProvider(IN_MEMORY_FLAGS)) context.client = get_client() @given("a provider is registered") def step_impl_provider(context): - set_provider(InMemoryProvider(IN_MEMORY_FLAGS)) + set_provider_and_wait(InMemoryProvider(IN_MEMORY_FLAGS)) context.client = get_client() diff --git a/tests/provider/test_registry.py b/tests/provider/test_registry.py index b5e10503..1326170f 100644 --- a/tests/provider/test_registry.py +++ b/tests/provider/test_registry.py @@ -1,8 +1,9 @@ +import threading from unittest.mock import Mock import pytest -from openfeature.exception import GeneralError +from openfeature.exception import GeneralError, ProviderFatalError from openfeature.provider import ProviderStatus from openfeature.provider._registry import ProviderRegistry from openfeature.provider.no_op_provider import NoOpProvider @@ -67,8 +68,8 @@ def test_registering_provider_for_first_time_initializes_it(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain1", provider) - registry.set_provider("domain2", provider) + registry.set_provider("domain1", provider, wait_for_init=True) + registry.set_provider("domain2", provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -103,7 +104,7 @@ def test_setting_default_provider_initializes_it(): registry = ProviderRegistry() provider = Mock() - registry.set_default_provider(provider) + registry.set_default_provider(provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -114,8 +115,8 @@ def test_registering_provider_as_default_then_domain_only_initializes_once(): registry = ProviderRegistry() provider = Mock() - registry.set_default_provider(provider) - registry.set_provider("domain", provider) + registry.set_default_provider(provider, wait_for_init=True) + registry.set_provider("domain", provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -126,8 +127,8 @@ def test_registering_provider_as_domain_then_default_only_initializes_once(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain", provider) - registry.set_default_provider(provider) + registry.set_provider("domain", provider, wait_for_init=True) + registry.set_default_provider(provider, wait_for_init=True) provider.initialize.assert_called_once() @@ -191,7 +192,7 @@ def test_initializing_provider_sets_status_ready(): assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY - registry.set_provider("domain", provider) + registry.set_provider("domain", provider, wait_for_init=True) provider.initialize.assert_called_once() assert registry.get_provider_status(provider) == ProviderStatus.READY @@ -203,7 +204,7 @@ def test_shutting_down_provider_sets_status_not_ready(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain", provider) + registry.set_provider("domain", provider, wait_for_init=True) assert registry.get_provider_status(provider) == ProviderStatus.READY registry.shutdown() @@ -216,8 +217,8 @@ def test_clearing_registry_resets_providers_and_default(): registry = ProviderRegistry() provider = Mock() - registry.set_provider("domain", provider) - registry.set_default_provider(provider) + registry.set_provider("domain", provider, wait_for_init=True) + registry.set_default_provider(provider, wait_for_init=True) registry.clear_providers() @@ -228,3 +229,53 @@ def test_clearing_registry_resets_providers_and_default(): provider.initialize.assert_called_once() provider.shutdown.assert_called_once() + + +def test_set_provider_returns_before_initialization_completes(): + """Test that set_provider (non-blocking) returns before initialize finishes.""" + + registry = ProviderRegistry() + init_started = threading.Event() + init_may_proceed = threading.Event() + provider = Mock() + + def slow_initialize(ctx): + init_started.set() + init_may_proceed.wait() + + provider.initialize.side_effect = slow_initialize + + registry.set_provider("domain", provider) + + assert init_started.wait(timeout=2), "initialize was never called in background" + assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY + + init_may_proceed.set() # unblock the background thread + + +def test_set_provider_and_wait_blocks_until_ready(): + """Test that set_provider with wait_for_init=True blocks until READY.""" + + registry = ProviderRegistry() + initialized = threading.Event() + provider = Mock() + + def tracking_initialize(ctx): + initialized.set() + + provider.initialize.side_effect = tracking_initialize + + registry.set_provider("domain", provider, wait_for_init=True) + + assert initialized.is_set() + assert registry.get_provider_status(provider) == ProviderStatus.READY + + +def test_set_provider_and_wait_reraises_on_error(): + """Test that set_provider with wait_for_init=True re-raises initialization errors.""" + registry = ProviderRegistry() + provider = Mock() + provider.initialize.side_effect = ProviderFatalError() + + with pytest.raises(ProviderFatalError): + registry.set_provider("domain", provider, wait_for_init=True) diff --git a/tests/test_api.py b/tests/test_api.py index cacdf694..b7945cbb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,3 +1,4 @@ +import threading from unittest.mock import MagicMock import pytest @@ -14,6 +15,7 @@ remove_handler, set_evaluation_context, set_provider, + set_provider_and_wait, shutdown, ) from openfeature.evaluation_context import EvaluationContext @@ -69,7 +71,7 @@ def test_should_invoke_provider_initialize_function_on_newly_registered_provider # When set_evaluation_context(evaluation_context) - set_provider(provider) + set_provider_and_wait(provider) # Then provider.initialize.assert_called_with(evaluation_context) @@ -170,10 +172,10 @@ def test_should_provide_a_function_to_bind_provider_through_domain(): def test_should_not_initialize_provider_already_bound_to_another_domain(): # Given provider = MagicMock(spec=FeatureProvider) - set_provider(provider, "foo") + set_provider_and_wait(provider, "foo") # When - set_provider(provider, "bar") + set_provider_and_wait(provider, "bar") # Then provider.initialize.assert_called_once() @@ -326,7 +328,7 @@ def test_add_remove_event_handler(): def test_handlers_attached_to_provider_already_in_associated_state_should_run_immediately(): # Given provider = NoOpProvider() - set_provider(provider) + set_provider_and_wait(provider) spy = MagicMock() # When @@ -345,7 +347,7 @@ def test_provider_ready_handlers_run_if_provider_initialize_function_terminates_ spy.reset_mock() # reset the mock to avoid counting the immediate call on subscribe # When - set_provider(provider) + set_provider_and_wait(provider) # Then spy.provider_ready.assert_called_once() @@ -360,7 +362,8 @@ def test_provider_error_handlers_run_if_provider_initialize_function_terminates_ add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) # When - set_provider(provider) + with pytest.raises(ProviderFatalError): + set_provider_and_wait(provider) # Then spy.provider_error.assert_called_once() @@ -369,7 +372,7 @@ def test_provider_error_handlers_run_if_provider_initialize_function_terminates_ def test_provider_status_is_updated_after_provider_emits_event(): # Given provider = NoOpProvider() - set_provider(provider) + set_provider_and_wait(provider) client = get_client() # When @@ -393,3 +396,103 @@ def test_provider_status_is_updated_after_provider_emits_event(): provider.emit_provider_ready(ProviderEventDetails()) # Then assert client.get_provider_status() == ProviderStatus.READY + + +# Non-blocking set_provider tests + + +def test_set_provider_returns_before_initialization_completes(): + # Given: a provider whose initialize blocks until signalled + init_started = threading.Event() + init_may_proceed = threading.Event() + + provider = MagicMock(spec=FeatureProvider) + + def slow_initialize(ctx): + init_started.set() + init_may_proceed.wait() + + provider.initialize.side_effect = slow_initialize + + # When + set_provider(provider) + + # Then: set_provider returned before initialize completed (we reached this line + # while the background thread is still blocked inside initialize) + assert init_started.wait(timeout=2), "initialize was never called" + init_may_proceed.set() # unblock the background thread + + +def test_provider_status_is_not_ready_during_async_initialization(): + # Given: a provider whose initialize blocks until signalled + init_may_proceed = threading.Event() + provider = MagicMock(spec=FeatureProvider) + + def slow_initialize(ctx): + init_may_proceed.wait() + + provider.initialize.side_effect = slow_initialize + + # When + set_provider(provider) + client = get_client() + + # Then: status is NOT_READY while init is still running + assert client.get_provider_status() == ProviderStatus.NOT_READY + + # Cleanup: let the background thread finish + init_may_proceed.set() + + +def test_set_provider_and_wait_blocks_until_initialization_completes(): + # Given + initialized = threading.Event() + provider = MagicMock(spec=FeatureProvider) + + def slow_initialize(ctx): + initialized.set() + + provider.initialize.side_effect = slow_initialize + + # When + set_provider_and_wait(provider) + + # Then: initialize was called before set_provider_and_wait returned + assert initialized.is_set() + assert get_client().get_provider_status() == ProviderStatus.READY + + +def test_set_provider_and_wait_reraises_on_failure(): + # Given + provider = MagicMock(spec=FeatureProvider) + provider.initialize.side_effect = ProviderFatalError() + + # When / Then + with pytest.raises(ProviderFatalError): + set_provider_and_wait(provider) + + +def test_set_provider_swallows_error_and_emits_provider_error_event(): + # Given + provider = MagicMock(spec=FeatureProvider) + error_fired = threading.Event() + + def failing_initialize(ctx): + raise ProviderFatalError() + + provider.initialize.side_effect = failing_initialize + + spy = MagicMock() + + def on_error(details): + spy.on_error(details) + error_fired.set() + + add_handler(ProviderEvent.PROVIDER_ERROR, on_error) + + # When: non-blocking set_provider — must not raise + set_provider(provider) + + # Then: error event fired, exception was not propagated + assert error_fired.wait(timeout=2), "PROVIDER_ERROR event was never fired" + spy.on_error.assert_called_once()