diff --git a/pkg/util/wait/loop.go b/pkg/util/wait/loop.go index 0dd13c626..107bfc132 100644 --- a/pkg/util/wait/loop.go +++ b/pkg/util/wait/loop.go @@ -40,6 +40,10 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding var timeCh <-chan time.Time doneCh := ctx.Done() + if !sliding { + timeCh = t.C() + } + // if immediate is true the condition is // guaranteed to be executed at least once, // if we haven't requested immediate execution, delay once @@ -50,17 +54,27 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding }(); err != nil || ok { return err } - } else { + } + + if sliding { timeCh = t.C() + } + + for { + + // Wait for either the context to be cancelled or the next invocation be called select { case <-doneCh: return ctx.Err() case <-timeCh: } - } - for { - // checking ctx.Err() is slightly faster than checking a select + // IMPORTANT: Because there is no channel priority selection in golang + // it is possible for very short timers to "win" the race in the previous select + // repeatedly even when the context has been canceled. We therefore must + // explicitly check for context cancellation on every loop and exit if true to + // guarantee that we don't invoke condition more than once after context has + // been cancelled. if err := ctx.Err(); err != nil { return err } @@ -77,21 +91,5 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding if sliding { t.Next() } - - if timeCh == nil { - timeCh = t.C() - } - - // NOTE: b/c there is no priority selection in golang - // it is possible for this to race, meaning we could - // trigger t.C and doneCh, and t.C select falls through. - // In order to mitigate we re-check doneCh at the beginning - // of every loop to guarantee at-most one extra execution - // of condition. - select { - case <-doneCh: - return ctx.Err() - case <-timeCh: - } } } diff --git a/pkg/util/wait/loop_test.go b/pkg/util/wait/loop_test.go index 992d3d04d..63bfa8540 100644 --- a/pkg/util/wait/loop_test.go +++ b/pkg/util/wait/loop_test.go @@ -99,6 +99,7 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) { cancelContextAfter int attemptsExpected int errExpected error + timer Timer }{ { name: "condition successful is only one attempt", @@ -203,45 +204,88 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) { attemptsExpected: 0, errExpected: context.DeadlineExceeded, }, + { + name: "context canceled before the second execution and immediate", + immediate: true, + context: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + callback: func(attempts int) (bool, error) { + return false, nil + }, + attemptsExpected: 1, + errExpected: context.DeadlineExceeded, + timer: Backoff{Duration: 2 * time.Second}.Timer(), + }, + { + name: "immediate and long duration of condition and sliding false", + immediate: true, + sliding: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + callback: func(attempts int) (bool, error) { + if attempts >= 4 { + return true, nil + } + time.Sleep(time.Second / 5) + return false, nil + }, + attemptsExpected: 4, + timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(), + }, + { + name: "immediate and long duration of condition and sliding true", + immediate: true, + sliding: true, + context: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + callback: func(attempts int) (bool, error) { + if attempts >= 4 { + return true, nil + } + time.Sleep(time.Second / 5) + return false, nil + }, + errExpected: context.DeadlineExceeded, + attemptsExpected: 3, + timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(), + }, } for _, test := range tests { - for _, immediate := range []bool{true, false} { - t.Run(fmt.Sprintf("immediate=%t", immediate), func(t *testing.T) { - for _, sliding := range []bool{true, false} { - t.Run(fmt.Sprintf("sliding=%t", sliding), func(t *testing.T) { - t.Run(test.name, func(t *testing.T) { - contextFn := test.context - if contextFn == nil { - contextFn = defaultContext - } - ctx, cancel := contextFn() - defer cancel() - - timer := Backoff{Duration: time.Microsecond}.Timer() - attempts := 0 - err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) { - attempts++ - defer func() { - if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { - cancel() - } - }() - return test.callback(attempts) - }) - - if test.errExpected != err { - t.Errorf("expected error: %v but got: %v", test.errExpected, err) - } - - if test.attemptsExpected != attempts { - t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) - } - }) - }) - } + t.Run(test.name, func(t *testing.T) { + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() + + timer := test.timer + if timer == nil { + timer = Backoff{Duration: time.Microsecond}.Timer() + } + attempts := 0 + err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) { + attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() + return test.callback(attempts) }) - } + + if test.errExpected != err { + t.Errorf("expected error: %v but got: %v", test.errExpected, err) + } + + if test.attemptsExpected != attempts { + t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) + } + }) } }