Skip to content

Commit

Permalink
Merge pull request #119762 from AxeZhan/PollUntilContextCancel
Browse files Browse the repository at this point in the history
wait.PollUntilContextCancel immediately executes condition once

Kubernetes-commit: 227d1b2357d93a6884addccb50122df16674ca95
  • Loading branch information
k8s-publishing-bot committed Nov 2, 2023
2 parents 16d50e6 + 5916a9f commit bc0a03b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 55 deletions.
38 changes: 18 additions & 20 deletions pkg/util/wait/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
var timeCh <-chan time.Time
doneCh := ctx.Done()

if !sliding {
timeCh = t.C()
}

// if immediate is true the condition is
// guaranteed to be executed at least once,
// if we haven't requested immediate execution, delay once
Expand All @@ -50,17 +54,27 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
}(); err != nil || ok {
return err
}
} else {
}

if sliding {
timeCh = t.C()
}

for {

// Wait for either the context to be cancelled or the next invocation be called
select {
case <-doneCh:
return ctx.Err()
case <-timeCh:
}
}

for {
// checking ctx.Err() is slightly faster than checking a select
// IMPORTANT: Because there is no channel priority selection in golang
// it is possible for very short timers to "win" the race in the previous select
// repeatedly even when the context has been canceled. We therefore must
// explicitly check for context cancellation on every loop and exit if true to
// guarantee that we don't invoke condition more than once after context has
// been cancelled.
if err := ctx.Err(); err != nil {
return err
}
Expand All @@ -77,21 +91,5 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
if sliding {
t.Next()
}

if timeCh == nil {
timeCh = t.C()
}

// NOTE: b/c there is no priority selection in golang
// it is possible for this to race, meaning we could
// trigger t.C and doneCh, and t.C select falls through.
// In order to mitigate we re-check doneCh at the beginning
// of every loop to guarantee at-most one extra execution
// of condition.
select {
case <-doneCh:
return ctx.Err()
case <-timeCh:
}
}
}
114 changes: 79 additions & 35 deletions pkg/util/wait/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) {
cancelContextAfter int
attemptsExpected int
errExpected error
timer Timer
}{
{
name: "condition successful is only one attempt",
Expand Down Expand Up @@ -203,45 +204,88 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) {
attemptsExpected: 0,
errExpected: context.DeadlineExceeded,
},
{
name: "context canceled before the second execution and immediate",
immediate: true,
context: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
callback: func(attempts int) (bool, error) {
return false, nil
},
attemptsExpected: 1,
errExpected: context.DeadlineExceeded,
timer: Backoff{Duration: 2 * time.Second}.Timer(),
},
{
name: "immediate and long duration of condition and sliding false",
immediate: true,
sliding: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
callback: func(attempts int) (bool, error) {
if attempts >= 4 {
return true, nil
}
time.Sleep(time.Second / 5)
return false, nil
},
attemptsExpected: 4,
timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(),
},
{
name: "immediate and long duration of condition and sliding true",
immediate: true,
sliding: true,
context: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
callback: func(attempts int) (bool, error) {
if attempts >= 4 {
return true, nil
}
time.Sleep(time.Second / 5)
return false, nil
},
errExpected: context.DeadlineExceeded,
attemptsExpected: 3,
timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(),
},
}

for _, test := range tests {
for _, immediate := range []bool{true, false} {
t.Run(fmt.Sprintf("immediate=%t", immediate), func(t *testing.T) {
for _, sliding := range []bool{true, false} {
t.Run(fmt.Sprintf("sliding=%t", sliding), func(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
contextFn := test.context
if contextFn == nil {
contextFn = defaultContext
}
ctx, cancel := contextFn()
defer cancel()

timer := Backoff{Duration: time.Microsecond}.Timer()
attempts := 0
err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) {
attempts++
defer func() {
if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
cancel()
}
}()
return test.callback(attempts)
})

if test.errExpected != err {
t.Errorf("expected error: %v but got: %v", test.errExpected, err)
}

if test.attemptsExpected != attempts {
t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
}
})
})
}
t.Run(test.name, func(t *testing.T) {
contextFn := test.context
if contextFn == nil {
contextFn = defaultContext
}
ctx, cancel := contextFn()
defer cancel()

timer := test.timer
if timer == nil {
timer = Backoff{Duration: time.Microsecond}.Timer()
}
attempts := 0
err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) {
attempts++
defer func() {
if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
cancel()
}
}()
return test.callback(attempts)
})
}

if test.errExpected != err {
t.Errorf("expected error: %v but got: %v", test.errExpected, err)
}

if test.attemptsExpected != attempts {
t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
}
})
}
}

Expand Down

0 comments on commit bc0a03b

Please sign in to comment.