From 41fab10f79f591b18c1e8c04fab75c1c12d08ec0 Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Mon, 28 Jul 2025 14:52:37 -0400 Subject: [PATCH] tests: add edge cases, improve TestComments compliance --- retry_test.go | 540 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 487 insertions(+), 53 deletions(-) diff --git a/retry_test.go b/retry_test.go index 8f965f5..fa5d851 100644 --- a/retry_test.go +++ b/retry_test.go @@ -7,6 +7,7 @@ import ( "math" "os" "reflect" + "sync" "testing" "time" ) @@ -22,7 +23,7 @@ func TestDoWithDataAllFailed(t *testing.T) { t.Fatal("expected error, got nil") } if v != 0 { - t.Errorf("got v=%d, want 0", v) + t.Errorf("returned value: got %d, want 0", v) } expectedErrorFormat := `All attempts fail: @@ -47,7 +48,7 @@ func TestDoWithDataAllFailed(t *testing.T) { if err.Error() != expectedErrorFormat { t.Errorf("error message: got %q, want %q", err.Error(), expectedErrorFormat) } - if retrySum != uint(36) { + if retrySum != 36 { t.Errorf("retry sum: got %d, want 36", retrySum) } } @@ -61,8 +62,8 @@ func TestDoFirstOk(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if uint(0) != retrySum { - t.Errorf("retrySum (expected no retries): got %d, want 0", retrySum) + if retrySum != 0 { + t.Errorf("retrySum: got %d, want 0 (no retries expected)", retrySum) } } @@ -77,11 +78,11 @@ func TestDoWithDataFirstOk(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if returnVal != val { + if val != returnVal { t.Errorf("return value: got %d, want %d", val, returnVal) } - if uint(0) != retrySum { - t.Errorf("retrySum (expected no retries): got %d, want 0", retrySum) + if retrySum != 0 { + t.Errorf("retrySum: got %d, want 0 (no retries expected)", retrySum) } } @@ -119,7 +120,7 @@ func TestRetryIf(t *testing.T) { if err.Error() != expectedErrorFormat { t.Errorf("error message: got %q, want %q", err.Error(), expectedErrorFormat) } - if uint(2) != retryCount { + if retryCount != 2 { t.Errorf("retry count: got %d, want 2", retryCount) } } @@ -147,7 +148,7 @@ func TestRetryIf_ZeroAttempts(t *testing.T) { t.Fatal("expected error, got nil") } - if "special" != err.Error() { + if err.Error() != "special" { t.Errorf("error message: got %q, want %q", err.Error(), "special") } if retryCount != onRetryCount+1 { @@ -175,7 +176,7 @@ func TestZeroAttemptsWithError(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if maxErrors != count { + if count != maxErrors { t.Errorf("execution count: got %d, want %d", count, maxErrors) } } @@ -195,7 +196,7 @@ func TestZeroAttemptsWithoutError(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - if 1 != count { + if count != 1 { t.Errorf("execution count: got %d, want 1", count) } } @@ -213,7 +214,7 @@ func TestZeroAttemptsWithUnrecoverableError(t *testing.T) { } expectedErr := Unrecoverable(errors.New("test error")) if err.Error() != expectedErr.Error() { - t.Errorf("got %v, want %v", err, expectedErr) + t.Errorf("error: got %v, want %v", err, expectedErr) } } @@ -421,11 +422,11 @@ func TestBackOffDelay(t *testing.T) { delay: c.delay, } delay := BackOffDelay(c.n, nil, &config) - if c.expectedMaxN != config.maxBackOffN { - t.Errorf("max n mismatch: got %v, want %v", config.maxBackOffN, c.expectedMaxN) + if config.maxBackOffN != c.expectedMaxN { + t.Errorf("max n: got %v, want %v", config.maxBackOffN, c.expectedMaxN) } - if c.expectedDelay != delay { - t.Errorf("delay duration mismatch: got %v, want %v", delay, c.expectedDelay) + if delay != c.expectedDelay { + t.Errorf("delay duration: got %v, want %v", delay, c.expectedDelay) } }, ) @@ -480,8 +481,8 @@ func TestCombineDelay(t *testing.T) { funcs[i] = f(d) } actual := CombineDelay(funcs...)(0, nil, nil) - if c.expected != actual { - t.Errorf("delay duration mismatch: got %v, want %v", actual, c.expected) + if actual != c.expected { + t.Errorf("delay duration: got %v, want %v", actual, c.expected) } }, ) @@ -505,8 +506,8 @@ func TestContext(t *testing.T) { if err == nil { t.Fatal("expected error, got nil") } - if !(dur < defaultDelay) { - t.Errorf("cancellation timing: got duration=%v, want = defaultDelay { + t.Errorf("cancellation timing: got %v, want <%v", dur, defaultDelay) } if retrySum != 0 { t.Errorf("retry count: got %d, want 0", retrySum) @@ -537,12 +538,12 @@ func TestContext(t *testing.T) { #3: context canceled` if retryErr, ok := err.(Error); ok { if len(retryErr) != 3 { - t.Errorf("expected len=%d, got %d", 3, len(retryErr)) + t.Errorf("error count: got %d, want 3", len(retryErr)) } } else { t.Fatalf("expected Error type, got %T", err) } - if expectedErrorFormat != err.Error() { + if err.Error() != expectedErrorFormat { t.Errorf("error message: got %q, want %q", err.Error(), expectedErrorFormat) } if retrySum != 2 { @@ -565,7 +566,7 @@ func TestContext(t *testing.T) { Context(ctx), LastErrorOnly(true), ) - if context.Canceled != err { + if err != context.Canceled { t.Errorf("error: got %v, want %v", err, context.Canceled) } @@ -593,7 +594,7 @@ func TestContext(t *testing.T) { Attempts(0), ) - if context.Canceled != err { + if err != context.Canceled { t.Errorf("error: got %v, want %v", err, context.Canceled) } @@ -651,29 +652,30 @@ func TestContext(t *testing.T) { }) } -type testTimer struct { - called bool -} - -func (t *testTimer) After(d time.Duration) <-chan time.Time { - t.called = true - return time.After(d) -} - func TestTimerInterface(t *testing.T) { - var timer testTimer + timer := &testTimer{} + attempts := 0 err := Do( - func() error { return errors.New("test") }, - Attempts(1), + func() error { + attempts++ + if attempts < 2 { + return errors.New("test") + } + return nil + }, + Attempts(3), Delay(10*time.Millisecond), MaxDelay(50*time.Millisecond), - WithTimer(&timer), + WithTimer(timer), ) - if err == nil { - t.Fatal("expected error, got nil") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !timer.called { + t.Error("expected timer.After to be called") } - } func TestErrorIs(t *testing.T) { @@ -684,13 +686,13 @@ func TestErrorIs(t *testing.T) { e = append(e, closedErr) if !errors.Is(e, expectErr) { - t.Error("IsRecoverable(err) = false, want true") + t.Error("errors.Is(e, expectErr): got false, want true") } if !errors.Is(e, closedErr) { - t.Error("IsRecoverable(err) = false, want true") + t.Error("errors.Is(e, closedErr): got false, want true") } if errors.Is(e, errors.New("error")) { - t.Error("IsRecoverable(err) = true, want false") + t.Error("errors.Is(e, new error): got true, want false") } } @@ -715,12 +717,12 @@ func TestErrorAs(t *testing.T) { var tb barErr if !errors.As(e, &tf) { - t.Error("IsRecoverable(err) = false, want true") + t.Error("errors.As(e, &fooErr): got false, want true") } if errors.As(e, &tb) { - t.Error("IsRecoverable(err) = true, want false") + t.Error("errors.As(e, &barErr): got true, want false") } - if "foo" != tf.str { + if tf.str != "foo" { t.Errorf("fooErr.str: got %q, want %q", tf.str, "foo") } } @@ -737,7 +739,7 @@ func TestUnwrap(t *testing.T) { if err == nil { t.Fatal("expected error, got nil") } - if testError != errors.Unwrap(err) { + if errors.Unwrap(err) != testError { t.Errorf("unwrapped error: got %v, want %v", errors.Unwrap(err), testError) } } @@ -823,7 +825,7 @@ func TestAttemptsForErrorNoDelayAfterFinalAttempt(t *testing.T) { if err == nil { t.Fatal("expected error, got nil") } - if uint64(2) != count { + if count != 2 { t.Errorf("attempt count: got %d, want 2", count) } if len(timestamps) != 2 { @@ -887,20 +889,452 @@ func TestOnRetryNotCalledOnLastAttempt(t *testing.T) { func TestIsRecoverable(t *testing.T) { err := errors.New("err") if !IsRecoverable(err) { - t.Error("IsRecoverable(err) = false, want true") + t.Error("IsRecoverable(err): got false, want true") } err = Unrecoverable(err) if IsRecoverable(err) { - t.Error("IsRecoverable(err) = true, want false") + t.Error("IsRecoverable(unrecoverable err): got true, want false") } err = fmt.Errorf("wrapping: %w", err) if IsRecoverable(err) { - t.Error("IsRecoverable(err) = true, want false") + t.Error("IsRecoverable(wrapped unrecoverable): got true, want false") + } +} + +func TestPanicRecovery(t *testing.T) { + t.Run("panic in retryable function", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + // Good - panic was not swallowed + } else { + t.Error("expected panic to propagate, but it was swallowed") + } + }() + + err := Do(func() error { + panic("test panic") + }) + // Should not reach here + t.Errorf("expected panic, got error: %v", err) + }) + + t.Run("panic in OnRetry callback", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + // Good - panic was not swallowed + } else { + t.Error("expected panic to propagate from OnRetry") + } + }() + + err := Do( + func() error { return errors.New("test") }, + OnRetry(func(n uint, err error) { + panic("panic in callback") + }), + Attempts(2), + ) + t.Errorf("expected panic, got error: %v", err) + }) + + t.Run("panic in DelayType function", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + // Good - panic was not swallowed + } else { + t.Error("expected panic to propagate from DelayType") + } + }() + + err := Do( + func() error { return errors.New("test") }, + DelayType(func(n uint, err error, config *Config) time.Duration { + panic("panic in delay calculation") + }), + Attempts(2), + ) + t.Errorf("expected panic, got error: %v", err) + }) +} + +func TestContextWithCustomCause(t *testing.T) { + customErr := errors.New("custom cancellation reason") + ctx, cancel := context.WithCancelCause(context.Background()) + + go func() { + time.Sleep(50 * time.Millisecond) + cancel(customErr) + }() + + err := Do( + func() error { + time.Sleep(100 * time.Millisecond) + return errors.New("test") + }, + Context(ctx), + Attempts(5), + ) + + if !errors.Is(err, customErr) { + t.Errorf("expected custom cancellation cause in error chain, got: %v", err) + } +} + +func TestConcurrentRetryUsage(t *testing.T) { + // Test that retry is safe for concurrent use + var wg sync.WaitGroup + goroutines := 20 // Reduced from 100 for faster tests + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + count := 0 + err := Do( + func() error { + count++ + if count < 3 { + return fmt.Errorf("error from goroutine %d", id) + } + return nil + }, + Attempts(5), + Delay(0), // No delay for speed + ) + + if err != nil { + t.Errorf("goroutine %d: unexpected error: %v", id, err) + } + if count != 3 { + t.Errorf("goroutine %d: expected 3 attempts, got %d", id, count) + } + }(i) + } + + wg.Wait() +} + +func TestDoWithDataGenericEdgeCases(t *testing.T) { + t.Run("nil pointer return", func(t *testing.T) { + result, err := DoWithData(func() (*string, error) { + return nil, nil + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != nil { + t.Errorf("expected nil result, got: %v", result) + } + }) + + t.Run("interface{} return type", func(t *testing.T) { + expected := map[string]interface{}{ + "key": "value", + "num": 42, + } + result, err := DoWithData(func() (interface{}, error) { + return expected, nil + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("result: got %v, want %v", result, expected) + } + }) + + t.Run("channel type", func(t *testing.T) { + ch := make(chan int, 1) + ch <- 42 + + result, err := DoWithData(func() (chan int, error) { + return ch, nil + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != ch { + t.Errorf("expected same channel, got different channel") + } + }) + + t.Run("function type", func(t *testing.T) { + fn := func(x int) int { return x * 2 } + + result, err := DoWithData(func() (func(int) int, error) { + return fn, nil + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // Can't compare functions directly, but we can test behavior + if result(21) != 42 { + t.Errorf("returned function behavior differs") + } + }) +} + +func TestVeryLargeDelayOverflow(t *testing.T) { + // Test with delays near MaxInt64 + largeDelay := time.Duration(math.MaxInt64) - time.Hour + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := Do( + func() error { return errors.New("test") }, + Context(ctx), + Delay(largeDelay), + Attempts(2), + DelayType(func(n uint, err error, config *Config) time.Duration { + // Try to cause overflow + return config.delay + time.Hour + }), + ) + + // Should timeout, not panic or hang + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected deadline exceeded, got: %v", err) } } +func TestErrorAccumulationAtCapacity(t *testing.T) { + // Test that error accumulation is capped to prevent unbounded memory growth + // We verify the error slice capacity is pre-allocated and capped correctly + + testCases := []struct { + name string + attempts uint + expectedCap int + expectedLen int + }{ + {"small attempts", 10, 10, 10}, + {"medium attempts", 100, 100, 100}, + {"at cap", 1000, 1000, 1000}, + {"over cap", 1500, 1000, 50}, // Run only 50 attempts to avoid timeout + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + attempts := 0 + err := Do( + func() error { + attempts++ + // For large attempt counts, stop early to avoid timeout + if tc.attempts >= 1000 && attempts >= tc.expectedLen { + return nil + } + return fmt.Errorf("error %d", attempts) + }, + Attempts(tc.attempts), + Delay(0), + DelayType(FixedDelay), + ) + + // Should have succeeded on large tests + if tc.attempts >= 1000 && err != nil { + if errList, ok := err.(Error); ok { + // Verify capacity is capped + if cap(errList) != tc.expectedCap { + t.Errorf("error slice capacity: got %d, want %d", cap(errList), tc.expectedCap) + } + } + } else if tc.attempts < 1000 { + // Small tests should fail all attempts + if err == nil { + t.Fatal("expected error, got nil") + } + errList, ok := err.(Error) + if !ok { + t.Fatalf("expected Error type, got %T", err) + } + if cap(errList) != tc.expectedCap { + t.Errorf("error slice capacity: got %d, want %d", cap(errList), tc.expectedCap) + } + } + }) + } +} + +func TestRetryIfWithChangingConditions(t *testing.T) { + // Test RetryIf function that changes behavior based on external state + var shouldRetry bool = true + attempts := 0 + + err := Do( + func() error { + attempts++ + if attempts == 3 { + shouldRetry = false // Change condition mid-retry + } + return errors.New("test error") + }, + RetryIf(func(err error) bool { + return shouldRetry + }), + Attempts(10), + Delay(time.Millisecond), + ) + + if err == nil { + t.Fatal("expected error, got nil") + } + + // Should stop at attempt 3 + // Attempt 1: error, shouldRetry=true, retryIf returns true → continue + // Attempt 2: error, shouldRetry=true, retryIf returns true → continue + // Attempt 3: sets shouldRetry=false, error, retryIf returns false → stop + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestCustomTimerEdgeCases(t *testing.T) { + t.Run("timer returns closed channel", func(t *testing.T) { + closedCh := make(chan time.Time) + close(closedCh) + + timer := &testTimer{ + afterFunc: func(d time.Duration) <-chan time.Time { + return closedCh + }, + } + + attempts := 0 + err := Do( + func() error { + attempts++ + if attempts < 3 { + return errors.New("test") + } + return nil + }, + WithTimer(timer), + Attempts(5), + ) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } + }) +} + +func TestComplexErrorChains(t *testing.T) { + // Create a complex error chain + baseErr := errors.New("base error") + wrappedOnce := fmt.Errorf("wrapped once: %w", baseErr) + wrappedTwice := fmt.Errorf("wrapped twice: %w", wrappedOnce) + customErr := &fooErr{str: "custom error"} + wrappedCustom := fmt.Errorf("wrapped custom: %w", customErr) + + attempts := 0 + err := Do( + func() error { + attempts++ + switch attempts { + case 1: + return wrappedTwice + case 2: + return wrappedCustom + case 3: + return Unrecoverable(wrappedOnce) + default: + return nil + } + }, + Attempts(5), + ) + + if err == nil { + t.Fatal("expected error, got nil") + } + + // Verify error chain contains all expected errors + if !errors.Is(err, baseErr) { + t.Error("error chain should contain base error") + } + + // Check if err contains fooErr + // The second attempt returns wrappedCustom which contains &fooErr + var fe fooErr + found := false + + // Check if we can find it directly + if errors.As(err, &fe) { + found = true + } else if errList, ok := err.(Error); ok { + // Check each error in the list + for _, e := range errList { + var tempFe *fooErr + if errors.As(e, &tempFe) { + found = true + break + } + } + } + + if !found { + t.Error("error chain should contain fooErr") + } + + // Should stop at unrecoverable + if attempts != 3 { + t.Errorf("expected 3 attempts (stopped at unrecoverable), got %d", attempts) + } +} + +func TestRetryWithNilContext(t *testing.T) { + // Even though we validate context isn't nil, test defensive programming + defer func() { + if r := recover(); r != nil { + t.Logf("recovered from panic as expected: %v", r) + } + }() + + config := &Config{ + attempts: 3, + delay: time.Millisecond, + retryIf: IsRecoverable, + delayType: FixedDelay, + timer: &timerImpl{}, + context: nil, // Intentionally nil + onRetry: func(n uint, err error) {}, + } + + // This should be caught by validate() + retryableFunc := func() (interface{}, error) { + return nil, errors.New("test") + } + + _, err := DoWithData(retryableFunc, func(c *Config) { + *c = *config + }) + + if err == nil || err.Error() != "context cannot be nil" { + t.Errorf("expected context validation error, got: %v", err) + } +} + +// Update testTimer to support custom behavior +type testTimer struct { + called bool + afterFunc func(time.Duration) <-chan time.Time +} + +func (t *testTimer) After(d time.Duration) <-chan time.Time { + t.called = true + if t.afterFunc != nil { + return t.afterFunc(d) + } + return time.After(d) +} + func TestFullJitterBackoffDelay(t *testing.T) { // Seed for predictable randomness in tests // In real usage, math/rand is auto-seeded in Go 1.20+ or should be seeded once at program start. @@ -957,12 +1391,12 @@ func TestFullJitterBackoffDelay(t *testing.T) { // Test case where baseDelay might be zero configZeroBase := &Config{delay: 0, maxDelay: maxDelay} delayZeroBase := FullJitterBackoffDelay(0, errors.New("test error"), configZeroBase) - if time.Duration(0) != delayZeroBase { + if delayZeroBase != 0 { t.Errorf("delay with zero base: got %v, want 0", delayZeroBase) } delayZeroBaseAttempt1 := FullJitterBackoffDelay(1, errors.New("test error"), configZeroBase) - if time.Duration(0) != delayZeroBaseAttempt1 { + if delayZeroBaseAttempt1 != 0 { t.Errorf("delay with zero base (attempt>0): got %v, want 0", delayZeroBaseAttempt1) }