diff --git a/Makefile b/Makefile index 67e1439..0cc0a83 100644 --- a/Makefile +++ b/Makefile @@ -14,4 +14,4 @@ lint: go tool govulncheck ./... bench: - go test -run=^$ -bench=. -benchmem ./... + go test -run=^$$ -bench=. -benchmem ./... diff --git a/errors.go b/errors.go index 25e1b51..215b693 100644 --- a/errors.go +++ b/errors.go @@ -60,18 +60,18 @@ type customError struct { basePath string // snapshot of basePath at capture time cause error wrapped error // immediate parent for Unwrap() chain; may differ from cause - shouldNotify bool + shouldNotify atomic.Bool status *grpcstatus.Status } // ShouldNotify returns true if the error should be reported to notifiers. func (c *customError) ShouldNotify() bool { - return c.shouldNotify + return c.shouldNotify.Load() } // Notified marks the error as having been notified (or not). func (c *customError) Notified(status bool) { - c.shouldNotify = !status + c.shouldNotify.Store(!status) } // Error returns the error message. @@ -233,30 +233,30 @@ func WrapWithSkipAndStatus(err error, msg string, skip int, status *grpcstatus.S //if we have stack information reuse that if e, ok := err.(ErrorExt); ok { c := &customError{ - Msg: msg + e.Error(), - cause: e.Cause(), - wrapped: err, // preserve full chain for errors.Is/errors.As - status: status, - shouldNotify: true, + Msg: msg + e.Error(), + cause: e.Cause(), + wrapped: err, // preserve full chain for errors.Is/errors.As + status: status, } + c.shouldNotify.Store(true) c.stack = e.Callers() if ce, ok := e.(*customError); ok { c.basePath = ce.basePath } if n, ok := e.(NotifyExt); ok { - c.shouldNotify = n.ShouldNotify() + c.shouldNotify.Store(n.ShouldNotify()) } return c } c := &customError{ - Msg: msg + err.Error(), - cause: err, - wrapped: err, - shouldNotify: true, - status: status, + Msg: msg + err.Error(), + cause: err, + wrapped: err, + status: status, } + c.shouldNotify.Store(true) c.captureStack(skip + 1) return c diff --git a/notifier/notifier.go b/notifier/notifier.go index 165ec9e..3db5480 100644 --- a/notifier/notifier.go +++ b/notifier/notifier.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" gobrake "github.com/airbrake/gobrake/v5" @@ -35,12 +36,20 @@ var ( hostname string traceHeader string = "x-trace-id" - // asyncSem is a semaphore that bounds the number of concurrent async - // notification goroutines. When full, new notifications are dropped - // to prevent goroutine explosion under sustained error bursts. - asyncSem = make(chan struct{}, 1000) ) +// asyncSem is a semaphore that bounds the number of concurrent async +// notification goroutines. When full, new notifications are dropped +// to prevent goroutine explosion under sustained error bursts. +// Stored as atomic.Pointer to eliminate the race between SetMaxAsyncNotifications +// and NotifyAsync goroutines reading the channel variable. +var asyncSem atomic.Pointer[chan struct{}] + +func init() { + ch := make(chan struct{}, 20) + asyncSem.Store(&ch) +} + const ( tracerID = "tracerId" ) @@ -50,11 +59,13 @@ var asyncSemOnce sync.Once // SetMaxAsyncNotifications sets the maximum number of concurrent async // notification goroutines. When the limit is reached, new async notifications // are dropped to prevent goroutine explosion under sustained error bursts. -// Default is 1000. Can only be called once; subsequent calls are no-ops. +// Default is 20. The first successful call wins; subsequent calls are no-ops. +// It is safe to call concurrently with NotifyAsync. func SetMaxAsyncNotifications(n int) { if n > 0 { asyncSemOnce.Do(func() { - asyncSem = make(chan struct{}, n) + ch := make(chan struct{}, n) + asyncSem.Store(&ch) }) } } @@ -67,7 +78,7 @@ func NotifyAsync(err error, rawData ...interface{}) error { if err == nil { return nil } - sem := asyncSem + sem := *asyncSem.Load() select { case sem <- struct{}{}: data := append([]interface{}(nil), rawData...) @@ -553,7 +564,9 @@ func SetTraceId(ctx context.Context) context.Context { func GetTraceId(ctx context.Context) string { if o := options.FromContext(ctx); o != nil { if data, found := o.Get(tracerID); found { - return data.(string) + if traceID, ok := data.(string); ok { + return traceID + } } } if logCtx := loggers.FromContext(ctx); logCtx != nil { diff --git a/notifier/notifier_test.go b/notifier/notifier_test.go new file mode 100644 index 0000000..29b1589 --- /dev/null +++ b/notifier/notifier_test.go @@ -0,0 +1,77 @@ +package notifier + +import ( + "context" + "sync" + "testing" + + "github.com/go-coldbrew/errors" + "github.com/go-coldbrew/options" +) + +func TestGetTraceId_NonStringValue(t *testing.T) { + // Regression test: GetTraceId must not panic when the tracerID + // option holds a non-string value. + ctx := options.AddToOptions(context.Background(), tracerID, 12345) + + // Before the fix this panicked with "interface conversion: interface {} is int, not string". + got := GetTraceId(ctx) + if got != "" { + t.Errorf("expected empty string for non-string tracerID, got %q", got) + } +} + +func TestGetTraceId_StringValue(t *testing.T) { + ctx := options.AddToOptions(context.Background(), tracerID, "abc-123") + + got := GetTraceId(ctx) + if got != "abc-123" { + t.Errorf("expected 'abc-123', got %q", got) + } +} + +func TestNotifyAsync_BoundedConcurrency(t *testing.T) { + // Use a 1-slot semaphore and pre-fill it to simulate a full pool. + ch := make(chan struct{}, 1) + ch <- struct{}{} // pre-fill: pool is now full + asyncSem.Store(&ch) + t.Cleanup(func() { + // Restore default. Drain first so cleanup is safe. + for len(ch) > 0 { + <-ch + } + def := make(chan struct{}, 20) + asyncSem.Store(&def) + }) + + // With the semaphore full, NotifyAsync must drop (hit default branch). + // It should not block and should not spawn a goroutine. + NotifyAsync(errors.New("should-drop")) + + // Verify the semaphore is still exactly full (1 token, capacity 1). + // If NotifyAsync had somehow acquired a slot, len would be < cap. + if len(ch) != cap(ch) { + t.Errorf("expected semaphore to remain full (len=%d, cap=%d); NotifyAsync should have dropped", len(ch), cap(ch)) + } +} + +func TestSetMaxAsyncNotifications_ConcurrentAccess(t *testing.T) { + // Regression test: SetMaxAsyncNotifications and NotifyAsync must not + // race on the asyncSem variable. Run with -race to verify. + var wg sync.WaitGroup + + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + SetMaxAsyncNotifications(50) + } + }() + go func() { + defer wg.Done() + for i := 0; i < 20; i++ { + NotifyAsync(errors.New("race test")) + } + }() + wg.Wait() +}