diff --git a/openfeature/context_aware_test.go b/openfeature/context_aware_test.go new file mode 100644 index 00000000..f8c470d2 --- /dev/null +++ b/openfeature/context_aware_test.go @@ -0,0 +1,803 @@ +package openfeature + +import ( + "context" + "errors" + "testing" + "time" +) + +// testContextAwareProvider is a test provider that implements ContextAwareStateHandler +type testContextAwareProvider struct { + initDelay time.Duration +} + +func (p *testContextAwareProvider) Metadata() Metadata { + return Metadata{Name: "test-context-aware-provider"} +} + +// InitWithContext implements ContextAwareStateHandler +func (p *testContextAwareProvider) InitWithContext(ctx context.Context, evalCtx EvaluationContext) error { + select { + case <-time.After(p.initDelay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// Init implements StateHandler for backward compatibility +func (p *testContextAwareProvider) Init(evalCtx EvaluationContext) error { + return p.InitWithContext(context.Background(), evalCtx) +} + +// ShutdownWithContext implements ContextAwareStateHandler +func (p *testContextAwareProvider) ShutdownWithContext(ctx context.Context) error { + select { + case <-time.After(p.initDelay): // Reuse delay for shutdown simulation + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// Shutdown implements StateHandler for backward compatibility +func (p *testContextAwareProvider) Shutdown() { + // For backward compatibility, use background context with no timeout + _ = p.ShutdownWithContext(context.Background()) +} + +func (p *testContextAwareProvider) BooleanEvaluation(ctx context.Context, flag string, defaultValue bool, flatCtx FlattenedContext) BoolResolutionDetail { + return BoolResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProvider) StringEvaluation(ctx context.Context, flag string, defaultValue string, flatCtx FlattenedContext) StringResolutionDetail { + return StringResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProvider) FloatEvaluation(ctx context.Context, flag string, defaultValue float64, flatCtx FlattenedContext) FloatResolutionDetail { + return FloatResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProvider) IntEvaluation(ctx context.Context, flag string, defaultValue int64, flatCtx FlattenedContext) IntResolutionDetail { + return IntResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProvider) ObjectEvaluation(ctx context.Context, flag string, defaultValue any, flatCtx FlattenedContext) InterfaceResolutionDetail { + return InterfaceResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProvider) Hooks() []Hook { + return []Hook{} +} + +func TestContextAwareInitialization(t *testing.T) { + // Save original state + originalAPI := api + originalEventing := eventing + defer func() { + api = originalAPI + eventing = originalEventing + }() + + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + t.Run("fast provider succeeds within timeout", func(t *testing.T) { + fastProvider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + err := SetProviderWithContextAndWait(ctx, fastProvider) + if err != nil { + t.Errorf("Expected fast provider to succeed, got error: %v", err) + } + }) + + t.Run("slow provider times out", func(t *testing.T) { + slowProvider := &testContextAwareProvider{initDelay: 800 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + err := SetProviderWithContextAndWait(ctx, slowProvider) + if err == nil { + t.Error("Expected timeout error but got success") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context deadline exceeded, got: %v", err) + } + }) + + t.Run("async initialization returns immediately", func(t *testing.T) { + asyncProvider := &testContextAwareProvider{initDelay: 200 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + start := time.Now() + err := SetProviderWithContext(ctx, asyncProvider) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("Async setup should not fail: %v", err) + } + if elapsed > 100*time.Millisecond { + t.Errorf("Async setup took too long: %v", elapsed) + } + }) + + t.Run("named provider with context works", func(t *testing.T) { + namedProvider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + err := SetNamedProviderWithContextAndWait(ctx, "test-domain", namedProvider) + if err != nil { + t.Errorf("Named provider should succeed: %v", err) + } + }) + + t.Run("backward compatibility with regular provider", func(t *testing.T) { + legacyProvider := &NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + err := SetProviderWithContextAndWait(ctx, legacyProvider) + if err != nil { + t.Errorf("Legacy provider should work: %v", err) + } + }) +} + +func TestContextAwareStateHandlerDetection(t *testing.T) { + // Test that the initializerWithContext function correctly detects ContextAwareStateHandler + evalCtx := EvaluationContext{} + + t.Run("detects ContextAwareStateHandler", func(t *testing.T) { + provider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + event, err := initializerWithContext(ctx, provider, evalCtx) + if err != nil { + t.Errorf("Context-aware provider should initialize successfully: %v", err) + } + if event.EventType != ProviderReady { + t.Errorf("Expected ProviderReady event, got: %v", event.EventType) + } + }) + + t.Run("falls back to regular StateHandler", func(t *testing.T) { + provider := &NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + event, err := initializerWithContext(ctx, provider, evalCtx) + if err != nil { + t.Errorf("Regular provider should initialize successfully: %v", err) + } + if event.EventType != ProviderReady { + t.Errorf("Expected ProviderReady event, got: %v", event.EventType) + } + }) + + t.Run("handles timeout in context-aware provider", func(t *testing.T) { + provider := &testContextAwareProvider{initDelay: 500 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + event, err := initializerWithContext(ctx, provider, evalCtx) + if err == nil { + t.Error("Expected timeout error") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected deadline exceeded, got: %v", err) + } + if event.EventType != ProviderError { + t.Errorf("Expected ProviderError event, got: %v", event.EventType) + } + }) +} + +func TestContextAwareShutdown(t *testing.T) { + // Save original state + originalAPI := api + originalEventing := eventing + defer func() { + api = originalAPI + eventing = originalEventing + }() + + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + t.Run("context-aware shutdown with timeout", func(t *testing.T) { + provider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + + // Set the provider first + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + err := SetProviderWithContextAndWait(ctx, provider) + if err != nil { + t.Errorf("Provider setup should succeed: %v", err) + } + + // Now replace it to trigger shutdown + newProvider := &testContextAwareProvider{initDelay: 10 * time.Millisecond} + err = SetProviderWithContextAndWait(ctx, newProvider) + if err != nil { + t.Errorf("Provider replacement should succeed: %v", err) + } + }) + + t.Run("shutdown timeout handling", func(t *testing.T) { + // Create a provider with long shutdown delay that would timeout during shutdown (not init) + slowShutdownProvider := &testContextAwareProvider{initDelay: 10 * time.Millisecond} // Fast init + + // Set the provider first with generous timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := SetProviderWithContextAndWait(ctx, slowShutdownProvider) + if err != nil { + t.Errorf("Provider setup should succeed: %v", err) + } + + // Replace with new provider - shutdown happens in background, so this should succeed + // even if the old provider takes a long time to shut down + fastProvider := &testContextAwareProvider{initDelay: 10 * time.Millisecond} + err = SetProviderWithContextAndWait(ctx, fastProvider) + if err != nil { + t.Errorf("Provider replacement should succeed even with slow shutdown: %v", err) + } + + // Wait a bit to let any background shutdown complete + time.Sleep(100 * time.Millisecond) + }) +} + +func TestGlobalContextAwareShutdown(t *testing.T) { + // Save original state + originalAPI := api + originalEventing := eventing + defer func() { + api = originalAPI + eventing = originalEventing + }() + + t.Run("shutdown with context affects all providers", func(t *testing.T) { + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + // Set up multiple providers + defaultProvider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + namedProvider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Set default provider + err := SetProviderWithContextAndWait(ctx, defaultProvider) + if err != nil { + t.Errorf("Default provider setup should succeed: %v", err) + } + + // Set named provider + err = SetNamedProviderWithContextAndWait(ctx, "test-service", namedProvider) + if err != nil { + t.Errorf("Named provider setup should succeed: %v", err) + } + + // Shutdown all providers with context + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer shutdownCancel() + + err = ShutdownWithContext(shutdownCtx) + if err != nil { + t.Errorf("Global shutdown should succeed: %v", err) + } + }) + + t.Run("shutdown timeout handling", func(t *testing.T) { + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + // Set up a provider with fast init but simulates long shutdown delay + slowShutdownProvider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} // Fast init + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Set the provider (this should succeed quickly) + err := SetProviderWithContextAndWait(ctx, slowShutdownProvider) + if err != nil { + t.Errorf("Provider setup should succeed: %v", err) + } + + // Create a provider that uses the initDelay for shutdown simulation too + // When shutdown is called, it will use the same delay, which would be longer than our timeout + // For this test, we'll create a new provider instance with a longer delay to simulate slow shutdown + testAPI.mu.Lock() + // Replace the provider's delay to simulate slow shutdown + if contextProvider, ok := testAPI.defaultProvider.(*testContextAwareProvider); ok { + contextProvider.initDelay = 5 * time.Second // This will be used by ShutdownWithContext + } + testAPI.mu.Unlock() + + // Try to shutdown with short timeout - this should timeout + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer shutdownCancel() + + err = ShutdownWithContext(shutdownCtx) + if err == nil { + t.Error("Expected shutdown timeout error") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context deadline exceeded, got: %v", err) + } + }) + + t.Run("backward compatibility with regular providers", func(t *testing.T) { + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + // Set up regular (non-context-aware) providers + defaultProvider := &NoopProvider{} + namedProvider := &NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Set providers + err := SetProviderWithContextAndWait(ctx, defaultProvider) + if err != nil { + t.Errorf("Default provider setup should succeed: %v", err) + } + + err = SetNamedProviderWithContextAndWait(ctx, "test-service", namedProvider) + if err != nil { + t.Errorf("Named provider setup should succeed: %v", err) + } + + // Shutdown should work even with non-context-aware providers + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer shutdownCancel() + + err = ShutdownWithContext(shutdownCtx) + if err != nil { + t.Errorf("Global shutdown should succeed with regular providers: %v", err) + } + }) +} + +// testContextAwareProviderWithShutdownDelay allows different delays for init and shutdown +type testContextAwareProviderWithShutdownDelay struct { + initDelay time.Duration + shutdownDelay time.Duration +} + +func (p *testContextAwareProviderWithShutdownDelay) Metadata() Metadata { + return Metadata{Name: "test-shutdown-delay-provider"} +} + +func (p *testContextAwareProviderWithShutdownDelay) InitWithContext(ctx context.Context, evalCtx EvaluationContext) error { + select { + case <-time.After(p.initDelay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (p *testContextAwareProviderWithShutdownDelay) Init(evalCtx EvaluationContext) error { + return p.InitWithContext(context.Background(), evalCtx) +} + +func (p *testContextAwareProviderWithShutdownDelay) ShutdownWithContext(ctx context.Context) error { + select { + case <-time.After(p.shutdownDelay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (p *testContextAwareProviderWithShutdownDelay) Shutdown() { + _ = p.ShutdownWithContext(context.Background()) +} + +func (p *testContextAwareProviderWithShutdownDelay) BooleanEvaluation(ctx context.Context, flag string, defaultValue bool, flatCtx FlattenedContext) BoolResolutionDetail { + return BoolResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProviderWithShutdownDelay) StringEvaluation(ctx context.Context, flag string, defaultValue string, flatCtx FlattenedContext) StringResolutionDetail { + return StringResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProviderWithShutdownDelay) FloatEvaluation(ctx context.Context, flag string, defaultValue float64, flatCtx FlattenedContext) FloatResolutionDetail { + return FloatResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProviderWithShutdownDelay) IntEvaluation(ctx context.Context, flag string, defaultValue int64, flatCtx FlattenedContext) IntResolutionDetail { + return IntResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProviderWithShutdownDelay) ObjectEvaluation(ctx context.Context, flag string, defaultValue any, flatCtx FlattenedContext) InterfaceResolutionDetail { + return InterfaceResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testContextAwareProviderWithShutdownDelay) Hooks() []Hook { + return []Hook{} +} + +func TestContextPropagationFixes(t *testing.T) { + // Save original state + originalAPI := api + originalEventing := eventing + defer func() { + api = originalAPI + eventing = originalEventing + }() + + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + t.Run("shutdown uses passed context timeout", func(t *testing.T) { + // Create provider with fast init but slow shutdown + provider := &testContextAwareProviderWithShutdownDelay{ + initDelay: 10 * time.Millisecond, // Fast init + shutdownDelay: 500 * time.Millisecond, // Slow shutdown + } + + // Set provider with long timeout - should succeed + initCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := SetProviderWithContextAndWait(initCtx, provider) + if err != nil { + t.Errorf("Provider setup should succeed: %v", err) + } + + // Replace provider with short timeout - shutdown should respect the timeout + newProvider := &testContextAwareProvider{initDelay: 10 * time.Millisecond} + + // Use a short timeout that's shorter than the shutdown delay + replaceCtx, replaceCancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer replaceCancel() + + start := time.Now() + err = SetProviderWithContextAndWait(replaceCtx, newProvider) + elapsed := time.Since(start) + + // The init should succeed quickly, shutdown happens async + if err != nil { + t.Errorf("Provider replacement should succeed: %v", err) + } + + // Should complete quickly since init is fast and shutdown is async + if elapsed > 100*time.Millisecond { + t.Errorf("Provider replacement took too long: %v (expected < 100ms)", elapsed) + } + + // Wait a bit to let shutdown complete + time.Sleep(100 * time.Millisecond) + }) + + t.Run("shutdown respects context cancellation", func(t *testing.T) { + // Reset API + exec = newEventExecutor() + testAPI = newEvaluationAPI(exec) + api = testAPI + eventing = exec + + provider := &testContextAwareProviderWithShutdownDelay{ + initDelay: 10 * time.Millisecond, + shutdownDelay: 5 * time.Second, // Very slow shutdown + } + + // Set up provider + err := SetProviderWithContextAndWait(context.Background(), provider) + if err != nil { + t.Errorf("Provider setup should succeed: %v", err) + } + + // Create a context that we'll cancel quickly + replaceCtx, cancel := context.WithCancel(context.Background()) + + // Start provider replacement + go func() { + time.Sleep(50 * time.Millisecond) + cancel() // Cancel context during operation + }() + + newProvider := &testContextAwareProvider{initDelay: 10 * time.Millisecond} + err = SetProviderWithContextAndWait(replaceCtx, newProvider) + // Should succeed because init is fast, shutdown is async + if err != nil { + t.Errorf("Provider replacement should succeed even with cancellation: %v", err) + } + }) +} + +func TestSimplifiedErrorHandling(t *testing.T) { + evalCtx := EvaluationContext{} + + t.Run("context cancellation error message", func(t *testing.T) { + provider := &testContextAwareProvider{initDelay: 200 * time.Millisecond} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + event, err := initializerWithContext(ctx, provider, evalCtx) + if err == nil { + t.Error("Expected error for cancelled context") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + if event.EventType != ProviderError { + t.Errorf("Expected ProviderError event, got: %v", event.EventType) + } + if event.Message != "Provider initialization cancelled" { + t.Errorf("Expected cancellation message, got: %q", event.Message) + } + }) + + t.Run("context timeout error message", func(t *testing.T) { + provider := &testContextAwareProvider{initDelay: 200 * time.Millisecond} + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + event, err := initializerWithContext(ctx, provider, evalCtx) + if err == nil { + t.Error("Expected timeout error") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context.DeadlineExceeded error, got: %v", err) + } + if event.EventType != ProviderError { + t.Errorf("Expected ProviderError event, got: %v", event.EventType) + } + if event.Message != "Provider initialization timed out" { + t.Errorf("Expected timeout message, got: %q", event.Message) + } + }) + + t.Run("provider init error takes precedence", func(t *testing.T) { + // Create a provider that returns a ProviderInitError even with context issues + provider := &testProviderInitError{ + initDelay: 50 * time.Millisecond, + initError: &ProviderInitError{ + ErrorCode: ProviderFatalCode, + Message: "Custom provider error", + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) // Longer than init + defer cancel() + + event, err := initializerWithContext(ctx, provider, evalCtx) + if err == nil { + t.Error("Expected provider init error") + } + + // Should get the custom provider error, not a context error + if event.EventType != ProviderError { + t.Errorf("Expected ProviderError event, got: %v", event.EventType) + } + if event.ErrorCode != ProviderFatalCode { + t.Errorf("Expected ProviderFatalCode, got: %v", event.ErrorCode) + } + if event.Message != "Custom provider error" { + t.Errorf("Expected custom error message, got: %q", event.Message) + } + }) +} + +// testProviderInitError is a provider that returns a specific ProviderInitError +type testProviderInitError struct { + initDelay time.Duration + initError *ProviderInitError +} + +func (p *testProviderInitError) Metadata() Metadata { + return Metadata{Name: "test-provider-init-error"} +} + +func (p *testProviderInitError) InitWithContext(ctx context.Context, evalCtx EvaluationContext) error { + select { + case <-time.After(p.initDelay): + return p.initError + case <-ctx.Done(): + // Still return the provider error even if context is cancelled + return p.initError + } +} + +func (p *testProviderInitError) Init(evalCtx EvaluationContext) error { + return p.InitWithContext(context.Background(), evalCtx) +} + +func (p *testProviderInitError) ShutdownWithContext(ctx context.Context) error { + return nil +} + +func (p *testProviderInitError) Shutdown() {} + +func (p *testProviderInitError) BooleanEvaluation(ctx context.Context, flag string, defaultValue bool, flatCtx FlattenedContext) BoolResolutionDetail { + return BoolResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testProviderInitError) StringEvaluation(ctx context.Context, flag string, defaultValue string, flatCtx FlattenedContext) StringResolutionDetail { + return StringResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testProviderInitError) FloatEvaluation(ctx context.Context, flag string, defaultValue float64, flatCtx FlattenedContext) FloatResolutionDetail { + return FloatResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testProviderInitError) IntEvaluation(ctx context.Context, flag string, defaultValue int64, flatCtx FlattenedContext) IntResolutionDetail { + return IntResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testProviderInitError) ObjectEvaluation(ctx context.Context, flag string, defaultValue any, flatCtx FlattenedContext) InterfaceResolutionDetail { + return InterfaceResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: ProviderResolutionDetail{Reason: DefaultReason}, + } +} + +func (p *testProviderInitError) Hooks() []Hook { + return []Hook{} +} + +func TestEdgeCases(t *testing.T) { + // Save original state + originalAPI := api + originalEventing := eventing + defer func() { + api = originalAPI + eventing = originalEventing + }() + + // Create fresh API for isolated testing + exec := newEventExecutor() + testAPI := newEvaluationAPI(exec) + api = testAPI + eventing = exec + + t.Run("rapid provider switching", func(t *testing.T) { + // Reset API + exec = newEventExecutor() + testAPI = newEvaluationAPI(exec) + api = testAPI + eventing = exec + + providers := []*testContextAwareProvider{ + {initDelay: 10 * time.Millisecond}, + {initDelay: 15 * time.Millisecond}, + {initDelay: 5 * time.Millisecond}, + } + + // Rapidly switch providers + for i, provider := range providers { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + err := SetProviderWithContextAndWait(ctx, provider) + cancel() + + if err != nil { + t.Errorf("Provider %d setup should succeed: %v", i, err) + } + } + + // Let any pending shutdowns complete + time.Sleep(200 * time.Millisecond) + }) + + t.Run("concurrent operations with different contexts", func(t *testing.T) { + // Reset API + exec = newEventExecutor() + testAPI = newEvaluationAPI(exec) + api = testAPI + eventing = exec + + // Use channels to coordinate goroutines + done := make(chan error, 2) + + // Start two concurrent provider operations + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + provider := &testContextAwareProvider{initDelay: 50 * time.Millisecond} + err := SetProviderWithContextAndWait(ctx, provider) + done <- err + }() + + go func() { + time.Sleep(25 * time.Millisecond) // Start slightly later + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + provider := &testContextAwareProvider{initDelay: 30 * time.Millisecond} + err := SetNamedProviderWithContextAndWait(ctx, "concurrent-test", provider) + done <- err + }() + + // Wait for both to complete + for i := range 2 { + if err := <-done; err != nil { + t.Errorf("Concurrent operation %d failed: %v", i, err) + } + } + }) +} diff --git a/openfeature/example_context_test.go b/openfeature/example_context_test.go new file mode 100644 index 00000000..2652dee9 --- /dev/null +++ b/openfeature/example_context_test.go @@ -0,0 +1,167 @@ +package openfeature_test + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/open-feature/go-sdk/openfeature" +) + +// ExampleSetProviderWithContext demonstrates asynchronous provider setup with timeout control. +func ExampleSetProviderWithContext() { + // Create a test provider for demonstration + provider := &openfeature.NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := openfeature.SetProviderWithContext(ctx, provider) + if err != nil { + log.Printf("Failed to start provider setup: %v", err) + return + } + + // Provider continues initializing in background + fmt.Println("Provider setup initiated") + // Output: Provider setup initiated +} + +// ExampleSetProviderWithContextAndWait demonstrates synchronous provider setup with error handling. +func ExampleSetProviderWithContextAndWait() { + // Create a test provider for demonstration + provider := &openfeature.NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + err := openfeature.SetProviderWithContextAndWait(ctx, provider) + if err != nil { + log.Printf("Provider initialization failed: %v", err) + return + } + + // Provider is now ready to use + fmt.Println("Provider is ready") + // Output: Provider is ready +} + +// ExampleSetNamedProviderWithContext demonstrates multi-tenant provider setup. +func ExampleSetNamedProviderWithContext() { + // Create test providers for different services + userProvider := &openfeature.NoopProvider{} + billingProvider := &openfeature.NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + err := openfeature.SetNamedProviderWithContext(ctx, "user-service", userProvider) + if err != nil { + log.Printf("Failed to setup user service provider: %v", err) + return + } + + err = openfeature.SetNamedProviderWithContext(ctx, "billing-service", billingProvider) + if err != nil { + log.Printf("Failed to setup billing service provider: %v", err) + return + } + + // Create clients for different domains + userClient := openfeature.NewClient("user-service") + billingClient := openfeature.NewClient("billing-service") + + fmt.Printf("User client domain: %s\n", userClient.Metadata().Domain()) + fmt.Printf("Billing client domain: %s\n", billingClient.Metadata().Domain()) + // Output: User client domain: user-service + // Billing client domain: billing-service +} + +// ExampleSetNamedProviderWithContextAndWait demonstrates critical service provider setup. +func ExampleSetNamedProviderWithContextAndWait() { + // Create a test provider for demonstration + criticalProvider := &openfeature.NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Wait for critical providers to be ready + err := openfeature.SetNamedProviderWithContextAndWait(ctx, "critical-service", criticalProvider) + if err != nil { + log.Printf("Critical provider failed to initialize: %v", err) + return + } + + // Now safe to use the client + client := openfeature.NewClient("critical-service") + enabled, _ := client.BooleanValue(context.Background(), "feature-x", false, openfeature.EvaluationContext{}) + + fmt.Printf("Critical service ready, feature-x enabled: %v\n", enabled) + // Output: Critical service ready, feature-x enabled: false +} + +// ExampleContextAwareStateHandler demonstrates how context-aware shutdown works automatically. +func ExampleContextAwareStateHandler() { + // Context-aware providers automatically use ShutdownWithContext when replaced + provider1 := &openfeature.NoopProvider{} + provider2 := &openfeature.NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Set first provider + err := openfeature.SetProviderWithContextAndWait(ctx, provider1) + if err != nil { + log.Printf("Provider setup failed: %v", err) + return + } + + // Replace with second provider - this triggers context-aware shutdown of provider1 if it supports it + err = openfeature.SetProviderWithContextAndWait(ctx, provider2) + if err != nil { + log.Printf("Provider replacement failed: %v", err) + return + } + + fmt.Println("Context-aware provider lifecycle completed") + // Output: Context-aware provider lifecycle completed +} + +// ExampleShutdownWithContext demonstrates graceful application shutdown with timeout control. +func ExampleShutdownWithContext() { + // Set up providers + provider1 := &openfeature.NoopProvider{} + provider2 := &openfeature.NoopProvider{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Set up multiple providers + err := openfeature.SetProviderWithContextAndWait(ctx, provider1) + if err != nil { + log.Printf("Provider setup failed: %v", err) + return + } + + err = openfeature.SetNamedProviderWithContextAndWait(ctx, "service-a", provider2) + if err != nil { + log.Printf("Named provider setup failed: %v", err) + return + } + + // Application is running... + + // When application is shutting down, use context-aware shutdown + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + + err = openfeature.ShutdownWithContext(shutdownCtx) + if err != nil { + log.Printf("Shutdown completed with errors: %v", err) + } else { + fmt.Println("All providers shut down successfully") + } + + // Output: All providers shut down successfully +} diff --git a/openfeature/interfaces.go b/openfeature/interfaces.go index d0ddf045..ad429e5b 100644 --- a/openfeature/interfaces.go +++ b/openfeature/interfaces.go @@ -18,6 +18,7 @@ type IEvaluation interface { SetEvaluationContext(evalCtx EvaluationContext) AddHooks(hooks ...Hook) Shutdown() + ShutdownWithContext(ctx context.Context) error IEventing } @@ -67,6 +68,12 @@ type evaluationImpl interface { SetLogger(l logr.Logger) ForEvaluation(clientName string) (FeatureProvider, []Hook, EvaluationContext) + + // Context-aware provider setup methods + SetProviderWithContext(ctx context.Context, provider FeatureProvider) error + SetProviderWithContextAndWait(ctx context.Context, provider FeatureProvider) error + SetNamedProviderWithContext(ctx context.Context, clientName string, provider FeatureProvider, async bool) error + SetNamedProviderWithContextAndWait(ctx context.Context, clientName string, provider FeatureProvider) error } // eventingImpl is an internal reference interface extending IEventing diff --git a/openfeature/interfaces_mock.go b/openfeature/interfaces_mock.go index 54c96d73..202f9aa3 100644 --- a/openfeature/interfaces_mock.go +++ b/openfeature/interfaces_mock.go @@ -205,6 +205,20 @@ func (mr *MockIEvaluationMockRecorder) Shutdown() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockIEvaluation)(nil).Shutdown)) } +// ShutdownWithContext mocks base method. +func (m *MockIEvaluation) ShutdownWithContext(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShutdownWithContext", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// ShutdownWithContext indicates an expected call of ShutdownWithContext. +func (mr *MockIEvaluationMockRecorder) ShutdownWithContext(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutdownWithContext", reflect.TypeOf((*MockIEvaluation)(nil).ShutdownWithContext), ctx) +} + // MockIClient is a mock of IClient interface. type MockIClient struct { ctrl *gomock.Controller @@ -934,6 +948,76 @@ func (mr *MockevaluationImplMockRecorder) Shutdown() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockevaluationImpl)(nil).Shutdown)) } +// ShutdownWithContext mocks base method. +func (m *MockevaluationImpl) ShutdownWithContext(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShutdownWithContext", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// ShutdownWithContext indicates an expected call of ShutdownWithContext. +func (mr *MockevaluationImplMockRecorder) ShutdownWithContext(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutdownWithContext", reflect.TypeOf((*MockevaluationImpl)(nil).ShutdownWithContext), ctx) +} + +// SetProviderWithContext mocks base method. +func (m *MockevaluationImpl) SetProviderWithContext(ctx context.Context, provider FeatureProvider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetProviderWithContext", ctx, provider) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetProviderWithContext indicates an expected call of SetProviderWithContext. +func (mr *MockevaluationImplMockRecorder) SetProviderWithContext(ctx, provider any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetProviderWithContext", reflect.TypeOf((*MockevaluationImpl)(nil).SetProviderWithContext), ctx, provider) +} + +// SetProviderWithContextAndWait mocks base method. +func (m *MockevaluationImpl) SetProviderWithContextAndWait(ctx context.Context, provider FeatureProvider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetProviderWithContextAndWait", ctx, provider) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetProviderWithContextAndWait indicates an expected call of SetProviderWithContextAndWait. +func (mr *MockevaluationImplMockRecorder) SetProviderWithContextAndWait(ctx, provider any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetProviderWithContextAndWait", reflect.TypeOf((*MockevaluationImpl)(nil).SetProviderWithContextAndWait), ctx, provider) +} + +// SetNamedProviderWithContext mocks base method. +func (m *MockevaluationImpl) SetNamedProviderWithContext(ctx context.Context, clientName string, provider FeatureProvider, async bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetNamedProviderWithContext", ctx, clientName, provider, async) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetNamedProviderWithContext indicates an expected call of SetNamedProviderWithContext. +func (mr *MockevaluationImplMockRecorder) SetNamedProviderWithContext(ctx, clientName, provider, async any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNamedProviderWithContext", reflect.TypeOf((*MockevaluationImpl)(nil).SetNamedProviderWithContext), ctx, clientName, provider, async) +} + +// SetNamedProviderWithContextAndWait mocks base method. +func (m *MockevaluationImpl) SetNamedProviderWithContextAndWait(ctx context.Context, clientName string, provider FeatureProvider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetNamedProviderWithContextAndWait", ctx, clientName, provider) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetNamedProviderWithContextAndWait indicates an expected call of SetNamedProviderWithContextAndWait. +func (mr *MockevaluationImplMockRecorder) SetNamedProviderWithContextAndWait(ctx, clientName, provider any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNamedProviderWithContextAndWait", reflect.TypeOf((*MockevaluationImpl)(nil).SetNamedProviderWithContextAndWait), ctx, clientName, provider) +} + // MockeventingImpl is a mock of eventingImpl interface. type MockeventingImpl struct { ctrl *gomock.Controller diff --git a/openfeature/openfeature.go b/openfeature/openfeature.go index 43e9ff05..69db3f29 100644 --- a/openfeature/openfeature.go +++ b/openfeature/openfeature.go @@ -1,6 +1,10 @@ package openfeature -import "github.com/go-logr/logr" +import ( + "context" + + "github.com/go-logr/logr" +) // api is the global evaluationImpl implementation. This is a singleton and there can only be one instance. var ( @@ -47,6 +51,30 @@ func SetProviderAndWait(provider FeatureProvider) error { return api.SetProviderAndWait(provider) } +// SetProviderWithContext sets the default [FeatureProvider] with context-aware initialization. +// If the provider implements ContextAwareStateHandler, InitWithContext will be called with the provided context. +// Provider initialization is asynchronous and status can be checked from provider status. +// Returns an error immediately if provider is nil, or if context is cancelled during setup. +// +// Use this function for non-blocking provider setup with timeout control where you want +// to continue application startup while the provider initializes in background. +// For providers that don't implement ContextAwareStateHandler, this behaves +// identically to SetProvider() but with timeout protection. +func SetProviderWithContext(ctx context.Context, provider FeatureProvider) error { + return api.SetProviderWithContext(ctx, provider) +} + +// SetProviderWithContextAndWait sets the default [FeatureProvider] with context-aware initialization and waits for completion. +// If the provider implements ContextAwareStateHandler, InitWithContext will be called with the provided context. +// Returns an error if initialization causes an error, or if context is cancelled during initialization. +// +// Use this function for synchronous provider setup with guaranteed readiness when you need +// application startup to wait for the provider before continuing. +// Recommended timeout values: 1-5s for local providers, 10-30s for network-based providers. +func SetProviderWithContextAndWait(ctx context.Context, provider FeatureProvider) error { + return api.SetProviderWithContextAndWait(ctx, provider) +} + // ProviderMetadata returns the default [FeatureProvider] metadata func ProviderMetadata() Metadata { return api.GetProviderMetadata() @@ -64,6 +92,27 @@ func SetNamedProviderAndWait(domain string, provider FeatureProvider) error { return api.SetNamedProvider(domain, provider, false) } +// SetNamedProviderWithContext sets a [FeatureProvider] mapped to the given [Client] domain with context-aware initialization. +// If the provider implements ContextAwareStateHandler, InitWithContext will be called with the provided context. +// Provider initialization is asynchronous and status can be checked from provider status. +// Returns an error immediately if provider is nil, or if context is cancelled during setup. +// +// Named providers allow different domains to use different feature flag providers, +// enabling multi-tenant applications or microservice architectures. +func SetNamedProviderWithContext(ctx context.Context, domain string, provider FeatureProvider) error { + return api.SetNamedProviderWithContext(ctx, domain, provider, true) +} + +// SetNamedProviderWithContextAndWait sets a provider mapped to the given [Client] domain with context-aware initialization and waits for completion. +// If the provider implements ContextAwareStateHandler, InitWithContext will be called with the provided context. +// Returns an error if initialization causes an error, or if context is cancelled during initialization. +// +// Use this for synchronous named provider setup where you need to ensure +// the provider is ready before proceeding. +func SetNamedProviderWithContextAndWait(ctx context.Context, domain string, provider FeatureProvider) error { + return api.SetNamedProviderWithContextAndWait(ctx, domain, provider) +} + // NamedProviderMetadata returns the named provider's Metadata func NamedProviderMetadata(name string) Metadata { return api.GetNamedProviderMetadata(name) @@ -102,3 +151,14 @@ func Shutdown() { api.Shutdown() initSingleton() } + +// ShutdownWithContext calls context-aware shutdown on all registered providers. +// If providers implement ContextAwareStateHandler, ShutdownWithContext will be called with the provided context. +// It resets the state of the API, removing all hooks, event handlers, and providers. +// This is intended to be called when your application is terminating. +// Returns an error if any provider shutdown fails or if context is cancelled during shutdown. +func ShutdownWithContext(ctx context.Context) error { + err := api.ShutdownWithContext(ctx) + initSingleton() + return err +} diff --git a/openfeature/openfeature_api.go b/openfeature/openfeature_api.go index 653002f0..715a4ca0 100644 --- a/openfeature/openfeature_api.go +++ b/openfeature/openfeature_api.go @@ -1,6 +1,7 @@ package openfeature import ( + "context" "errors" "fmt" "maps" @@ -62,7 +63,7 @@ func (api *evaluationAPI) SetNamedProvider(clientName string, provider FeaturePr oldProvider := api.namedProviders[clientName] api.namedProviders[clientName] = provider - err := api.initNewAndShutdownOld(clientName, provider, oldProvider, async) + err := api.initNewAndShutdownOld(context.Background(), clientName, provider, oldProvider, async) if err != nil { return err } @@ -88,6 +89,124 @@ func (api *evaluationAPI) GetNamedProviderMetadata(name string) Metadata { return provider.Metadata() } +// Context-aware provider setup methods + +// SetProviderWithContext sets the default FeatureProvider with context-aware initialization. +func (api *evaluationAPI) SetProviderWithContext(ctx context.Context, provider FeatureProvider) error { + return api.setProviderWithContext(ctx, provider, true) +} + +// SetProviderWithContextAndWait sets the default FeatureProvider with context-aware initialization and waits for completion. +func (api *evaluationAPI) SetProviderWithContextAndWait(ctx context.Context, provider FeatureProvider) error { + return api.setProviderWithContext(ctx, provider, false) +} + +// setProviderWithContext sets the default FeatureProvider of the evaluationAPI with context-aware initialization. +func (api *evaluationAPI) setProviderWithContext(ctx context.Context, provider FeatureProvider, async bool) error { + api.mu.Lock() + defer api.mu.Unlock() + + if provider == nil { + return errors.New("default provider cannot be set to nil") + } + + oldProvider := api.defaultProvider + api.defaultProvider = provider + + err := api.initNewAndShutdownOld(ctx, "", provider, oldProvider, async) + if err != nil { + return fmt.Errorf("failed to initialize default provider %q: %w", provider.Metadata().Name, err) + } + + err = api.eventExecutor.registerDefaultProvider(provider) + if err != nil { + return fmt.Errorf("failed to register default provider %q: %w", provider.Metadata().Name, err) + } + + return nil +} + +// SetNamedProviderWithContext sets a provider with client name using context-aware initialization. +func (api *evaluationAPI) SetNamedProviderWithContext(ctx context.Context, clientName string, provider FeatureProvider, async bool) error { + api.mu.Lock() + defer api.mu.Unlock() + + if provider == nil { + return errors.New("provider cannot be set to nil") + } + + // Initialize new named provider and Shutdown the old one + oldProvider := api.namedProviders[clientName] + api.namedProviders[clientName] = provider + + err := api.initNewAndShutdownOld(ctx, clientName, provider, oldProvider, async) + if err != nil { + return fmt.Errorf("failed to initialize named provider %q for domain %q: %w", provider.Metadata().Name, clientName, err) + } + + err = api.eventExecutor.registerNamedEventingProvider(clientName, provider) + if err != nil { + return fmt.Errorf("failed to register named provider %q for domain %q: %w", provider.Metadata().Name, clientName, err) + } + + return nil +} + +// SetNamedProviderWithContextAndWait sets a provider with client name using context-aware initialization and waits for completion. +func (api *evaluationAPI) SetNamedProviderWithContextAndWait(ctx context.Context, clientName string, provider FeatureProvider) error { + return api.SetNamedProviderWithContext(ctx, clientName, provider, false) +} + +// initNewAndShutdownOld is the main helper to initialise new FeatureProvider and Shutdown the old FeatureProvider. +// Always uses the context-aware initializer with the provided context. +// +// When shutting down old providers that implement ContextAwareStateHandler, a 10-second timeout +// is applied to prevent hanging if the provider becomes unresponsive during shutdown. +func (api *evaluationAPI) initNewAndShutdownOld(ctx context.Context, clientName string, newProvider FeatureProvider, oldProvider FeatureProvider, async bool) error { + if async { + go func(executor *eventExecutor, evalCtx EvaluationContext, ctx context.Context, provider FeatureProvider, clientName string) { + // for async initialization, error is conveyed as an event + event, _ := initializerWithContext(ctx, provider, evalCtx) + executor.states.Store(clientName, stateFromEventOrError(event, nil)) + executor.triggerEvent(event, provider) + }(api.eventExecutor, api.evalCtx, ctx, newProvider, clientName) + } else { + event, err := initializerWithContext(ctx, newProvider, api.evalCtx) + api.eventExecutor.states.Store(clientName, stateFromEventOrError(event, err)) + api.eventExecutor.triggerEvent(event, newProvider) + if err != nil { + return err + } + } + + v, ok := oldProvider.(StateHandler) + + // oldProvider can be nil or without state handling capability + if oldProvider == nil || !ok { + return nil + } + + namedProviders := slices.Collect(maps.Values(api.namedProviders)) + + // check for multiple bindings + if oldProvider == api.defaultProvider || slices.Contains(namedProviders, oldProvider) { + return nil + } + + go func(forShutdown StateHandler, parentCtx context.Context) { + // Check if the provider supports context-aware shutdown + if contextHandler, ok := forShutdown.(ContextAwareStateHandler); ok { + // Use the provided context directly - user controls timeout + _ = contextHandler.ShutdownWithContext(parentCtx) + } else { + // Fall back to regular shutdown for backward compatibility + forShutdown.Shutdown() + } + }(v, ctx) + + return nil +} + // GetNamedProviders returns named providers map. func (api *evaluationAPI) GetNamedProviders() map[string]FeatureProvider { api.mu.RLock() @@ -142,20 +261,43 @@ func (api *evaluationAPI) RemoveHandler(eventType EventType, callback EventCallb } func (api *evaluationAPI) Shutdown() { + // Use the context-aware shutdown with background context and ignore errors + // to maintain backward compatibility (Shutdown doesn't return an error) + _ = api.ShutdownWithContext(context.Background()) +} + +// ShutdownWithContext calls context-aware shutdown on all registered providers. +// If providers implement ContextAwareStateHandler, ShutdownWithContext will be called with the provided context. +// Returns an error if any provider shutdown fails or if context is cancelled during shutdown. +func (api *evaluationAPI) ShutdownWithContext(ctx context.Context) error { api.mu.Lock() defer api.mu.Unlock() - v, ok := api.defaultProvider.(StateHandler) - if ok { - v.Shutdown() + var errs []error + + // Shutdown default provider + if api.defaultProvider != nil { + if contextHandler, ok := api.defaultProvider.(ContextAwareStateHandler); ok { + if err := contextHandler.ShutdownWithContext(ctx); err != nil { + errs = append(errs, fmt.Errorf("default provider shutdown failed: %w", err)) + } + } else if stateHandler, ok := api.defaultProvider.(StateHandler); ok { + stateHandler.Shutdown() + } } - for _, provider := range api.namedProviders { - v, ok = provider.(StateHandler) - if ok { - v.Shutdown() + // Shutdown all named providers + for name, provider := range api.namedProviders { + if contextHandler, ok := provider.(ContextAwareStateHandler); ok { + if err := contextHandler.ShutdownWithContext(ctx); err != nil { + errs = append(errs, fmt.Errorf("named provider %q shutdown failed: %w", name, err)) + } + } else if stateHandler, ok := provider.(StateHandler); ok { + stateHandler.Shutdown() } } + + return errors.Join(errs...) } // ForEvaluation is a helper to retrieve transaction scoped operators. @@ -195,7 +337,7 @@ func (api *evaluationAPI) setProvider(provider FeatureProvider, async bool) erro oldProvider := api.defaultProvider api.defaultProvider = provider - err := api.initNewAndShutdownOld("", provider, oldProvider, async) + err := api.initNewAndShutdownOld(context.Background(), "", provider, oldProvider, async) if err != nil { return err } @@ -208,48 +350,10 @@ func (api *evaluationAPI) setProvider(provider FeatureProvider, async bool) erro return nil } -// initNewAndShutdownOld is a helper to initialise new FeatureProvider and Shutdown the old FeatureProvider. -func (api *evaluationAPI) initNewAndShutdownOld(clientName string, newProvider FeatureProvider, oldProvider FeatureProvider, async bool) error { - if async { - go func(executor *eventExecutor, ctx EvaluationContext) { - // for async initialization, error is conveyed as an event - event, _ := initializer(newProvider, ctx) - executor.states.Store(clientName, stateFromEventOrError(event, nil)) - executor.triggerEvent(event, newProvider) - }(api.eventExecutor, api.evalCtx) - } else { - event, err := initializer(newProvider, api.evalCtx) - api.eventExecutor.states.Store(clientName, stateFromEventOrError(event, err)) - api.eventExecutor.triggerEvent(event, newProvider) - if err != nil { - return err - } - } - - v, ok := oldProvider.(StateHandler) - - // oldProvider can be nil or without state handling capability - if oldProvider == nil || !ok { - return nil - } - - namedProviders := slices.Collect(maps.Values(api.namedProviders)) - - // check for multiple bindings - if oldProvider == api.defaultProvider || slices.Contains(namedProviders, oldProvider) { - return nil - } - - go func(forShutdown StateHandler) { - forShutdown.Shutdown() - }(v) - - return nil -} - -// initializer is a helper to execute provider initialization and generate appropriate event for the initialization -// It also returns an error if the initialization resulted in an error -func initializer(provider FeatureProvider, evalCtx EvaluationContext) (Event, error) { +// initializerWithContext is a context-aware helper to execute provider initialization and generate appropriate event for the initialization +// If the provider implements ContextAwareStateHandler, InitWithContext is called; otherwise, Init is called for backward compatibility. +// It also returns an error if the initialization resulted in an error or if the context is cancelled. +func initializerWithContext(ctx context.Context, provider FeatureProvider, evalCtx EvaluationContext) (Event, error) { event := Event{ ProviderName: provider.Metadata().Name, EventType: ProviderReady, @@ -258,6 +362,29 @@ func initializer(provider FeatureProvider, evalCtx EvaluationContext) (Event, er }, } + // Check for context-aware handler first + if contextHandler, ok := provider.(ContextAwareStateHandler); ok { + err := contextHandler.InitWithContext(ctx, evalCtx) + if err != nil { + event.EventType = ProviderError + + // Check for specific provider initialization errors first + var initErr *ProviderInitError + if errors.As(err, &initErr) { + event.ErrorCode = initErr.ErrorCode + event.Message = initErr.Message + } else if errors.Is(err, context.Canceled) { + event.Message = "Provider initialization cancelled" + } else if errors.Is(err, context.DeadlineExceeded) { + event.Message = "Provider initialization timed out" + } else { + event.Message = fmt.Sprintf("Provider initialization failed: %v", err) + } + } + return event, err + } + + // Fall back to regular StateHandler for backward compatibility handler, ok := provider.(StateHandler) if !ok { // Note - a provider without state handling capability can be assumed to be ready immediately. @@ -267,14 +394,13 @@ func initializer(provider FeatureProvider, evalCtx EvaluationContext) (Event, er err := handler.Init(evalCtx) if err != nil { event.EventType = ProviderError - event.Message = fmt.Sprintf("Provider initialization error, %v", err) + event.Message = fmt.Sprintf("Provider initialization failed: %v", err) var initErr *ProviderInitError if errors.As(err, &initErr) { event.EventType = ProviderError event.ErrorCode = initErr.ErrorCode event.Message = initErr.Message } - } return event, err diff --git a/openfeature/provider.go b/openfeature/provider.go index f02df0a0..ce5378dc 100644 --- a/openfeature/provider.go +++ b/openfeature/provider.go @@ -66,6 +66,26 @@ type StateHandler interface { Shutdown() } +// ContextAwareStateHandler extends StateHandler with context-aware initialization and shutdown +// for providers that need to respect request timeouts and cancellation. +// If a provider implements this interface, InitWithContext and ShutdownWithContext will be called instead of Init and Shutdown. +// +// Use this interface when your provider needs to: +// - Respect initialization/shutdown timeouts (e.g., network calls, database connections) +// - Support graceful cancellation during setup and teardown +// - Integrate with request-scoped contexts +// +// Best practices: +// - Always check ctx.Done() in long-running initialization and shutdown operations +// - Use reasonable timeout values (typically 5-30 seconds) +// - Return ctx.Err() when the context is cancelled +// - Maintain backward compatibility by implementing both interfaces +type ContextAwareStateHandler interface { + StateHandler // Embed existing interface for backward compatibility + InitWithContext(ctx context.Context, evaluationContext EvaluationContext) error + ShutdownWithContext(ctx context.Context) error +} + // Tracker is the contract for tracking // FeatureProvider can opt in for this behavior by implementing the interface type Tracker interface {